diff --git a/app.py b/app.py
index 8bf6bc3..1dc4614 100644
--- a/app.py
+++ b/app.py
@@ -83,9 +83,11 @@ app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=2)
from views.page import page
from views.user import user
from views.spider_control import spider_bp
+from views.workflow_api import workflow_bp
app.register_blueprint(page.pb)
app.register_blueprint(user.ub)
app.register_blueprint(spider_bp)
+app.register_blueprint(workflow_bp) # 注册工作流蓝图
# 首页路由
@app.route('/')
diff --git a/createTables.sql b/createTables.sql
index 044bb97..508ca12 100644
--- a/createTables.sql
+++ b/createTables.sql
@@ -46,4 +46,52 @@ CREATE TABLE `user` (
`id` int(11) NOT NULL AUTO_INCREMENT,
`createTime` varchar(255) DEFAULT NULL,
PRIMARY KEY (`id`)
-) ENGINE=InnoDB AUTO_INCREMENT=4 DEFAULT CHARSET=utf8;
\ No newline at end of file
+) ENGINE=InnoDB AUTO_INCREMENT=4 DEFAULT CHARSET=utf8;
+
+-- 爬虫模板表
+CREATE TABLE IF NOT EXISTS `crawler_templates` (
+ `id` VARCHAR(64) NOT NULL COMMENT '模板ID',
+ `name` VARCHAR(64) NOT NULL COMMENT '模板名称',
+ `description` VARCHAR(255) NULL COMMENT '模板描述',
+ `icon` VARCHAR(32) NULL COMMENT '图标',
+ `config` JSON NOT NULL COMMENT '配置JSON',
+ `created_at` DATETIME NOT NULL COMMENT '创建时间',
+ `updated_at` DATETIME NOT NULL COMMENT '更新时间',
+ `deleted` TINYINT(1) NOT NULL DEFAULT 0 COMMENT '是否删除',
+ PRIMARY KEY (`id`),
+ INDEX `idx_crawler_templates_name` (`name`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='爬虫配置模板表';
+
+-- 分析流程模板表
+CREATE TABLE IF NOT EXISTS `analysis_templates` (
+ `id` VARCHAR(64) NOT NULL COMMENT '模板ID',
+ `name` VARCHAR(64) NOT NULL COMMENT '模板名称',
+ `description` VARCHAR(255) NULL COMMENT '模板描述',
+ `icon` VARCHAR(32) NULL COMMENT '图标',
+ `components` JSON NOT NULL COMMENT '组件JSON',
+ `connections` JSON NOT NULL COMMENT '连接JSON',
+ `created_at` DATETIME NOT NULL COMMENT '创建时间',
+ `updated_at` DATETIME NOT NULL COMMENT '更新时间',
+ `deleted` TINYINT(1) NOT NULL DEFAULT 0 COMMENT '是否删除',
+ PRIMARY KEY (`id`),
+ INDEX `idx_analysis_templates_name` (`name`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='分析流程模板表';
+
+-- 工作流执行任务表
+CREATE TABLE IF NOT EXISTS `workflow_tasks` (
+ `id` VARCHAR(64) NOT NULL COMMENT '任务ID',
+ `template_id` VARCHAR(64) NULL COMMENT '关联模板ID',
+ `type` VARCHAR(32) NOT NULL COMMENT '任务类型:crawler/analysis',
+ `status` VARCHAR(16) NOT NULL COMMENT '任务状态:pending/running/completed/failed',
+ `progress` INT(11) NOT NULL DEFAULT 0 COMMENT '进度百分比',
+ `config` JSON NOT NULL COMMENT '任务配置',
+ `result` JSON NULL COMMENT '执行结果',
+ `error` TEXT NULL COMMENT '错误信息',
+ `started_at` DATETIME NULL COMMENT '开始时间',
+ `completed_at` DATETIME NULL COMMENT '完成时间',
+ `created_at` DATETIME NOT NULL COMMENT '创建时间',
+ `updated_at` DATETIME NOT NULL COMMENT '更新时间',
+ PRIMARY KEY (`id`),
+ INDEX `idx_workflow_tasks_type_status` (`type`, `status`),
+ INDEX `idx_workflow_tasks_template` (`template_id`)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='工作流执行任务表';
\ No newline at end of file
diff --git a/static/js/workflow_editor.js b/static/js/workflow_editor.js
new file mode 100644
index 0000000..a5ec74b
--- /dev/null
+++ b/static/js/workflow_editor.js
@@ -0,0 +1,1065 @@
+document.addEventListener('DOMContentLoaded', function() {
+ // 工作流编辑器的主要元素
+ const workflowCanvas = document.getElementById('workflowCanvas');
+ const connectionsSvg = document.getElementById('connectionsSvg');
+
+ // 工作流数据对象
+ let workflowData = {
+ metadata: {
+ name: '新建工作流',
+ description: '',
+ created: new Date().toISOString(),
+ modified: new Date().toISOString()
+ },
+ nodes: [],
+ connections: []
+ };
+
+ // 拖拽相关变量
+ let isDragging = false;
+ let dragTarget = null;
+ let dragOffset = { x: 0, y: 0 };
+
+ // 连接相关变量
+ let isConnecting = false;
+ let connectionStart = null;
+ let connectionPreviewPath = null;
+
+ // 设置编辑器网格背景
+ setEditorBackground();
+
+ // 初始化组件面板拖拽
+ initializeComponentDrag();
+
+ function setEditorBackground() {
+ workflowCanvas.style.backgroundSize = '20px 20px';
+ workflowCanvas.style.backgroundImage = `
+ linear-gradient(to right, #f0f0f0 1px, transparent 1px),
+ linear-gradient(to bottom, #f0f0f0 1px, transparent 1px)
+ `;
+ }
+
+ function initializeComponentDrag() {
+ const components = document.querySelectorAll('.component-item');
+ components.forEach(component => {
+ component.setAttribute('draggable', 'true');
+ component.addEventListener('dragstart', function(e) {
+ e.dataTransfer.setData('componentType', this.dataset.type);
+ e.dataTransfer.setData('componentSubtype', this.dataset.subtype);
+ });
+ });
+ }
+
+ // 创建模板卡片
+ function createTemplateCard(template) {
+ const div = document.createElement('div');
+ div.className = 'template-item';
+
+ div.innerHTML = `
+
+
+ ${template.name}
+
+ ${template.description || '无描述'}
+
+
+
+ `;
+
+ // 添加加载模板的事件
+ const loadBtn = div.querySelector('.load-template-btn');
+ loadBtn.addEventListener('click', function() {
+ loadWorkflow(template.id);
+ });
+
+ return div;
+ }
+
+ function loadWorkflow(templateId) {
+ // 加载特定工作流模板的逻辑
+ fetch(`/api/workflow/${templateId}`)
+ .then(response => response.json())
+ .then(data => {
+ if (data.success) {
+ clearWorkflow();
+ renderWorkflow(data.workflow);
+ } else {
+ alert('加载工作流失败: ' + data.error);
+ }
+ })
+ .catch(error => {
+ console.error('加载工作流出错:', error);
+ alert('加载工作流时发生错误');
+ });
+ }
+
+ function clearWorkflow() {
+ // 清除画布上的所有节点和连接
+ workflowCanvas.querySelectorAll('.workflow-node').forEach(node => {
+ node.parentNode.removeChild(node);
+ });
+
+ connectionsSvg.querySelectorAll('.connection-path').forEach(path => {
+ path.parentNode.removeChild(path);
+ });
+
+ // 清空数据
+ workflowData.nodes = [];
+ workflowData.connections = [];
+ }
+
+ function renderWorkflow(workflowToRender) {
+ // 设置工作流元数据
+ workflowData.metadata = workflowToRender.metadata || {
+ name: '未命名工作流',
+ description: '',
+ created: new Date().toISOString(),
+ modified: new Date().toISOString()
+ };
+
+ // 渲染节点
+ if (workflowToRender.nodes && Array.isArray(workflowToRender.nodes)) {
+ workflowToRender.nodes.forEach(node => {
+ const nodeElement = createNodeFromData(node);
+ if (nodeElement) {
+ setupNodeEvents(nodeElement, node);
+ }
+ });
+ }
+
+ // 渲染连接
+ if (workflowToRender.connections && Array.isArray(workflowToRender.connections)) {
+ workflowToRender.connections.forEach(conn => {
+ workflowData.connections.push(conn);
+ drawConnection(conn.sourceId, conn.targetId, conn.id);
+ });
+ }
+ }
+
+ function createNodeFromData(nodeData) {
+ // 从数据创建节点DOM元素
+ const nodeElement = document.createElement('div');
+ nodeElement.className = 'workflow-node';
+ nodeElement.id = nodeData.id;
+ nodeElement.style.left = nodeData.x + 'px';
+ nodeElement.style.top = nodeData.y + 'px';
+
+ // 构建节点内容
+ nodeElement.innerHTML = `
+
+
+
${nodeData.config ? '已配置' : '点击配置参数'}
+
+
+
+
+
+ `;
+
+ workflowCanvas.appendChild(nodeElement);
+
+ // 添加到节点数据
+ workflowData.nodes.push(nodeData);
+
+ return nodeElement;
+ }
+
+ // ====== 运行工作流 ======
+ document.getElementById('runWorkflowBtn').addEventListener('click', function() {
+ $('#runWorkflowModal').modal('show');
+ });
+
+ document.getElementById('confirmRunBtn').addEventListener('click', function() {
+ const shouldSave = document.getElementById('saveBeforeRun').checked;
+
+ if (shouldSave) {
+ // 如果选择了先保存
+ workflowData.metadata.modified = new Date().toISOString();
+ saveWorkflow(workflowData);
+ }
+
+ // 关闭确认对话框
+ $('#runWorkflowModal').modal('hide');
+
+ // 提交工作流执行
+ runWorkflow(workflowData);
+ });
+
+ function runWorkflow(workflow) {
+ // 发送工作流到服务器执行
+ fetch('/api/workflow/run', {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json'
+ },
+ body: JSON.stringify(workflow)
+ })
+ .then(response => response.json())
+ .then(data => {
+ if (data.success) {
+ // 显示任务状态监控
+ showTaskStatus(data.taskId);
+ } else {
+ alert('运行工作流失败: ' + data.error);
+ }
+ })
+ .catch(error => {
+ console.error('运行工作流出错:', error);
+ alert('运行工作流时发生错误,请重试');
+ });
+ }
+
+ // ====== 任务状态监控 ======
+ let taskStatusInterval = null;
+
+ function showTaskStatus(taskId) {
+ // 显示任务状态模态框
+ document.getElementById('taskIdDisplay').textContent = taskId;
+ document.getElementById('taskStatusDisplay').textContent = '运行中';
+ document.getElementById('taskStartTimeDisplay').textContent = new Date().toLocaleString();
+ document.getElementById('taskCompleteTimeDisplay').textContent = '-';
+ document.getElementById('taskProgressBar').style.width = '0%';
+ document.getElementById('taskResultPreview').innerHTML = '任务运行中,请稍候...
';
+
+ $('#taskStatusModal').modal('show');
+
+ // 开始定期检查任务状态
+ if (taskStatusInterval) {
+ clearInterval(taskStatusInterval);
+ }
+
+ pollTaskStatus(taskId);
+ taskStatusInterval = setInterval(() => pollTaskStatus(taskId), 3000);
+ }
+
+ function pollTaskStatus(taskId) {
+ fetch(`/api/task/${taskId}/status`)
+ .then(response => response.json())
+ .then(data => {
+ updateTaskStatusDisplay(data);
+
+ // 如果任务已完成或失败,停止轮询
+ if (data.status === 'completed' || data.status === 'failed') {
+ if (taskStatusInterval) {
+ clearInterval(taskStatusInterval);
+ taskStatusInterval = null;
+ }
+ }
+ })
+ .catch(error => {
+ console.error('获取任务状态出错:', error);
+ });
+ }
+
+ function updateTaskStatusDisplay(taskData) {
+ const statusDisplay = document.getElementById('taskStatusDisplay');
+ const progressBar = document.getElementById('taskProgressBar');
+ const resultPreview = document.getElementById('taskResultPreview');
+
+ statusDisplay.textContent = getStatusText(taskData.status);
+
+ // 更新进度条
+ progressBar.style.width = `${taskData.progress || 0}%`;
+
+ // 根据状态设置进度条颜色
+ progressBar.className = 'progress-bar';
+ if (taskData.status === 'completed') {
+ progressBar.classList.add('bg-success');
+ document.getElementById('taskCompleteTimeDisplay').textContent = new Date().toLocaleString();
+ } else if (taskData.status === 'failed') {
+ progressBar.classList.add('bg-danger');
+ document.getElementById('taskCompleteTimeDisplay').textContent = new Date().toLocaleString();
+ } else {
+ progressBar.classList.add('bg-primary');
+ }
+
+ // 显示结果预览
+ if (taskData.status === 'completed' && taskData.resultPreview) {
+ resultPreview.innerHTML = generateResultPreview(taskData.resultPreview);
+ } else if (taskData.status === 'failed' && taskData.error) {
+ resultPreview.innerHTML = `${taskData.error}
`;
+ }
+ }
+
+ function getStatusText(status) {
+ switch (status) {
+ case 'pending': return '排队中';
+ case 'running': return '运行中';
+ case 'completed': return '已完成';
+ case 'failed': return '失败';
+ default: return status;
+ }
+ }
+
+ function generateResultPreview(resultData) {
+ if (!resultData) return '无可用预览
';
+
+ let html = '';
+
+ if (resultData.type === 'text') {
+ html = `${resultData.content}`;
+ } else if (resultData.type === 'table') {
+ html = '';
+
+ // 表头
+ if (resultData.headers && resultData.headers.length) {
+ html += '';
+ resultData.headers.forEach(header => {
+ html += `| ${header} | `;
+ });
+ html += '
';
+ }
+
+ // 表内容
+ if (resultData.rows && resultData.rows.length) {
+ html += '';
+ resultData.rows.slice(0, 5).forEach(row => {
+ html += '';
+ row.forEach(cell => {
+ html += `| ${cell} | `;
+ });
+ html += '
';
+ });
+ html += '';
+ }
+
+ html += '
';
+
+ if (resultData.rows && resultData.rows.length > 5) {
+ html += `
显示前5行,共${resultData.rows.length}行
`;
+ }
+ html += '
';
+ } else if (resultData.type === 'chart') {
+ html = '';
+ }
+
+ return html;
+ }
+
+ document.getElementById('cancelTaskBtn').addEventListener('click', function() {
+ const taskId = document.getElementById('taskIdDisplay').textContent;
+
+ if (!taskId) return;
+
+ // 发送取消任务请求
+ fetch(`/api/task/${taskId}/cancel`, { method: 'POST' })
+ .then(response => response.json())
+ .then(data => {
+ if (data.success) {
+ alert('任务已取消');
+ if (taskStatusInterval) {
+ clearInterval(taskStatusInterval);
+ taskStatusInterval = null;
+ }
+
+ document.getElementById('taskStatusDisplay').textContent = '已取消';
+ } else {
+ alert('取消任务失败: ' + data.error);
+ }
+ })
+ .catch(error => {
+ console.error('取消任务出错:', error);
+ alert('取消任务时发生错误');
+ });
+ });
+
+ document.getElementById('viewResultBtn').addEventListener('click', function() {
+ const taskId = document.getElementById('taskIdDisplay').textContent;
+ if (!taskId) return;
+
+ // 跳转到结果页面
+ window.open(`/result/${taskId}`, '_blank');
+ });
+
+ // ====== 辅助函数 ======
+ function getComponentTypeLabel(type) {
+ const typeLabels = {
+ 'data_source': '数据源',
+ 'preprocessing': '数据处理',
+ 'model': '模型分析',
+ 'visualization': '可视化'
+ };
+
+ return typeLabels[type] || type;
+ }
+
+ function getDefaultConfig(type, subtype) {
+ // 根据组件类型返回默认配置
+ const defaults = {
+ 'data_source': {
+ 'database': { connectionString: '', query: '' },
+ 'file': { filePath: '', format: 'csv' },
+ 'crawler': { url: '', depth: 1, keywords: '' }
+ },
+ 'preprocessing': {
+ 'filter': { field: '', operator: 'contains', value: '' },
+ 'sort': { field: '', order: 'asc' },
+ 'aggregate': { groupBy: '', function: 'count' }
+ },
+ 'model': {
+ 'sentiment': { language: 'zh', algorithm: 'bayes' },
+ 'topic': { numTopics: 5, algorithm: 'lda' },
+ 'keywords': { topk: 10, algorithm: 'tfidf' },
+ 'summarize': { ratio: 0.2, algorithm: 'extractive' }
+ },
+ 'visualization': {
+ 'chart': { type: 'bar', title: '', xField: '', yField: '' },
+ 'table': { fields: [], pageSize: 10 },
+ 'wordcloud': { maxWords: 100, colorScheme: 'default' }
+ }
+ };
+
+ return defaults[type] && defaults[type][subtype] ? defaults[type][subtype] : {};
+ }
+
+ function getComponentConfigs(type, subtype) {
+ // 返回特定组件类型的配置选项
+ const configs = {
+ 'data_source': {
+ 'database': [
+ { id: 'connectionString', label: '连接字符串', type: 'text' },
+ { id: 'query', label: 'SQL查询', type: 'textarea' },
+ { id: 'limit', label: '结果限制', type: 'number' }
+ ],
+ 'file': [
+ { id: 'filePath', label: '文件路径', type: 'text' },
+ { id: 'format', label: '文件格式', type: 'select', options: [
+ { value: 'csv', label: 'CSV' },
+ { value: 'excel', label: 'Excel' },
+ { value: 'json', label: 'JSON' },
+ { value: 'txt', label: '文本文件' }
+ ] }
+ ],
+ 'crawler': [
+ { id: 'url', label: '起始URL', type: 'text' },
+ { id: 'depth', label: '爬取深度', type: 'number' },
+ { id: 'keywords', label: '关键词', type: 'text' },
+ { id: 'maxItems', label: '最大爬取数量', type: 'number' }
+ ]
+ },
+ 'preprocessing': {
+ 'filter': [
+ { id: 'field', label: '字段名', type: 'text' },
+ { id: 'operator', label: '操作符', type: 'select', options: [
+ { value: 'equals', label: '等于' },
+ { value: 'contains', label: '包含' },
+ { value: 'startsWith', label: '开头是' },
+ { value: 'endsWith', label: '结尾是' },
+ { value: 'greaterThan', label: '大于' },
+ { value: 'lessThan', label: '小于' }
+ ] },
+ { id: 'value', label: '值', type: 'text' }
+ ],
+ 'sort': [
+ { id: 'field', label: '排序字段', type: 'text' },
+ { id: 'order', label: '排序方向', type: 'select', options: [
+ { value: 'asc', label: '升序' },
+ { value: 'desc', label: '降序' }
+ ] }
+ ],
+ 'aggregate': [
+ { id: 'groupBy', label: '分组字段', type: 'text' },
+ { id: 'function', label: '聚合函数', type: 'select', options: [
+ { value: 'count', label: '计数' },
+ { value: 'sum', label: '求和' },
+ { value: 'avg', label: '平均值' },
+ { value: 'min', label: '最小值' },
+ { value: 'max', label: '最大值' }
+ ] },
+ { id: 'valueField', label: '值字段', type: 'text' }
+ ]
+ },
+ 'model': {
+ 'sentiment': [
+ { id: 'language', label: '语言', type: 'select', options: [
+ { value: 'zh', label: '中文' },
+ { value: 'en', label: '英文' }
+ ] },
+ { id: 'algorithm', label: '算法', type: 'select', options: [
+ { value: 'bayes', label: '朴素贝叶斯' },
+ { value: 'svm', label: '支持向量机' },
+ { value: 'bert', label: 'BERT' }
+ ] },
+ { id: 'textField', label: '文本字段', type: 'text' }
+ ],
+ 'topic': [
+ { id: 'numTopics', label: '主题数量', type: 'number' },
+ { id: 'algorithm', label: '算法', type: 'select', options: [
+ { value: 'lda', label: 'LDA' },
+ { value: 'nmf', label: 'NMF' }
+ ] },
+ { id: 'textField', label: '文本字段', type: 'text' }
+ ],
+ 'keywords': [
+ { id: 'topk', label: '关键词数量', type: 'number' },
+ { id: 'algorithm', label: '算法', type: 'select', options: [
+ { value: 'tfidf', label: 'TF-IDF' },
+ { value: 'textrank', label: 'TextRank' }
+ ] },
+ { id: 'textField', label: '文本字段', type: 'text' }
+ ],
+ 'summarize': [
+ { id: 'ratio', label: '摘要比例', type: 'number' },
+ { id: 'algorithm', label: '算法', type: 'select', options: [
+ { value: 'extractive', label: '抽取式摘要' },
+ { value: 'abstractive', label: '生成式摘要' }
+ ] },
+ { id: 'textField', label: '文本字段', type: 'text' }
+ ]
+ },
+ 'visualization': {
+ 'chart': [
+ { id: 'type', label: '图表类型', type: 'select', options: [
+ { value: 'bar', label: '柱状图' },
+ { value: 'line', label: '折线图' },
+ { value: 'pie', label: '饼图' },
+ { value: 'scatter', label: '散点图' }
+ ] },
+ { id: 'title', label: '图表标题', type: 'text' },
+ { id: 'xField', label: 'X轴字段', type: 'text' },
+ { id: 'yField', label: 'Y轴字段', type: 'text' },
+ { id: 'colorField', label: '颜色字段', type: 'text' }
+ ],
+ 'table': [
+ { id: 'fields', label: '显示字段(逗号分隔)', type: 'text' },
+ { id: 'pageSize', label: '每页记录数', type: 'number' },
+ { id: 'sortable', label: '允许排序', type: 'checkbox' }
+ ],
+ 'wordcloud': [
+ { id: 'textField', label: '文本字段', type: 'text' },
+ { id: 'maxWords', label: '最大词数', type: 'number' },
+ { id: 'colorScheme', label: '配色方案', type: 'select', options: [
+ { value: 'default', label: '默认' },
+ { value: 'warm', label: '暖色调' },
+ { value: 'cool', label: '冷色调' },
+ { value: 'rainbow', label: '彩虹色' }
+ ] }
+ ]
+ }
+ };
+
+ return configs[type] && configs[type][subtype] ? configs[type][subtype] : [];
+ }
+
+ // 初始加载示例模板
+ showSampleTemplates();
+
+ function saveWorkflow(workflowData) {
+ // 保存工作流到服务器
+ fetch('/api/workflow/save', {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json'
+ },
+ body: JSON.stringify(workflowData)
+ })
+ .then(response => response.json())
+ .then(data => {
+ if (data.success) {
+ alert('工作流保存成功');
+ } else {
+ alert('保存工作流失败: ' + data.error);
+ }
+ })
+ .catch(error => {
+ console.error('保存工作流出错:', error);
+ alert('保存工作流时发生错误');
+ });
+ }
+
+ // 显示示例模板(修复CORS错误)
+ function showSampleTemplates() {
+ // 使用示例模板数据,避免直接从文件系统加载API
+ const sampleTemplates = [
+ {
+ id: 'template_1',
+ name: '微博热搜分析模板',
+ description: '爬取微博热搜榜数据,分析热点话题和情感倾向',
+ icon: 'fire'
+ },
+ {
+ id: 'template_2',
+ name: '用户评论情感分析',
+ description: '分析用户评论的情感倾向,生成情感分布图表',
+ icon: 'heart'
+ },
+ {
+ id: 'template_3',
+ name: '话题趋势监测',
+ description: '监测特定话题的讨论热度变化及关键词提取',
+ icon: 'chart-line'
+ }
+ ];
+
+ try {
+ const container = document.getElementById('analysisTemplatesList');
+ if(container) {
+ container.innerHTML = '';
+ sampleTemplates.forEach(template => {
+ const templateDiv = createTemplateCard(template);
+ container.appendChild(templateDiv);
+ });
+ } else {
+ // 尝试其他容器
+ const alternativeContainer = document.getElementById('templateList') ||
+ document.getElementById('crawlerTemplatesList');
+ if (alternativeContainer) {
+ alternativeContainer.innerHTML = '';
+ sampleTemplates.forEach(template => {
+ const templateDiv = createTemplateCard(template);
+ alternativeContainer.appendChild(templateDiv);
+ });
+ } else {
+ console.warn('未找到合适的模板容器');
+ }
+ }
+ } catch (error) {
+ console.error('加载模板出错:', error);
+ }
+ }
+
+ // 模板拖放功能
+ workflowCanvas.addEventListener('dragover', function(e) {
+ e.preventDefault();
+ e.dataTransfer.dropEffect = 'copy';
+ });
+
+ workflowCanvas.addEventListener('drop', function(e) {
+ e.preventDefault();
+ const componentType = e.dataTransfer.getData('componentType');
+ const componentSubtype = e.dataTransfer.getData('componentSubtype');
+
+ if (componentType && componentSubtype) {
+ const rect = workflowCanvas.getBoundingClientRect();
+ const x = e.clientX - rect.left;
+ const y = e.clientY - rect.top;
+
+ addNode(componentType, componentSubtype, x, y);
+ }
+ });
+
+ // 添加其他初始化代码
+ document.addEventListener('DOMContentLoaded', function() {
+ initializeWorkflowEditor();
+ setupEventListeners();
+ showSampleTemplates();
+ });
+
+ function initializeWorkflowEditor() {
+ // 初始化编辑器的基本设置
+ workflowData = {
+ metadata: {
+ name: '新建工作流',
+ description: '',
+ created: new Date().toISOString(),
+ modified: new Date().toISOString()
+ },
+ nodes: [],
+ connections: []
+ };
+ }
+
+ function setupEventListeners() {
+ // 设置各种事件监听器
+ document.getElementById('saveWorkflowBtn').addEventListener('click', function() {
+ workflowData.metadata.modified = new Date().toISOString();
+ saveWorkflow(workflowData);
+ });
+ }
+
+ // 添加节点的函数
+ function addNode(componentType, componentSubtype, x, y) {
+ const nodeId = 'node_' + Date.now();
+ const nodeData = {
+ id: nodeId,
+ type: componentType,
+ subtype: componentSubtype,
+ title: getComponentTypeLabel(componentType) + '-' + componentSubtype,
+ x: x,
+ y: y,
+ config: getDefaultConfig(componentType, componentSubtype)
+ };
+
+ const nodeElement = createNodeFromData(nodeData);
+ setupNodeEvents(nodeElement, nodeData);
+ }
+
+ // 设置节点事件
+ function setupNodeEvents(nodeElement, nodeData) {
+ // 节点拖动事件
+ nodeElement.addEventListener('mousedown', function(e) {
+ if (e.target.closest('.port') || e.target.closest('.delete-node-btn')) {
+ return; // 如果点击的是端口或删除按钮,不处理拖动
+ }
+
+ isDragging = true;
+ dragTarget = nodeElement;
+ const rect = nodeElement.getBoundingClientRect();
+ dragOffset = {
+ x: e.clientX - rect.left,
+ y: e.clientY - rect.top
+ };
+
+ nodeElement.style.zIndex = '100';
+ });
+
+ // 删除节点
+ const deleteBtn = nodeElement.querySelector('.delete-node-btn');
+ deleteBtn.addEventListener('click', function() {
+ deleteNode(nodeData.id);
+ });
+
+ // 节点配置
+ nodeElement.addEventListener('click', function(e) {
+ if (!e.target.closest('.port') && !e.target.closest('.delete-node-btn')) {
+ openNodeConfig(nodeData);
+ }
+ });
+
+ // 连接处理
+ const ports = nodeElement.querySelectorAll('.port');
+ ports.forEach(port => {
+ port.addEventListener('mousedown', function(e) {
+ e.stopPropagation();
+ if (port.dataset.portType === 'output') {
+ startConnection(nodeData.id, e);
+ }
+ });
+
+ port.addEventListener('mouseup', function() {
+ if (isConnecting && connectionStart && connectionStart.id !== nodeData.id && port.dataset.portType === 'input') {
+ completeConnection(connectionStart.id, nodeData.id);
+ }
+ });
+ });
+ }
+
+ // 删除节点
+ function deleteNode(nodeId) {
+ const node = document.getElementById(nodeId);
+ if (node) {
+ node.parentNode.removeChild(node);
+ }
+
+ // 删除相关连接
+ workflowData.connections = workflowData.connections.filter(conn => {
+ if (conn.sourceId === nodeId || conn.targetId === nodeId) {
+ const path = document.getElementById('connection_' + conn.id);
+ if (path) {
+ path.parentNode.removeChild(path);
+ }
+ return false;
+ }
+ return true;
+ });
+
+ // 从数据中删除节点
+ workflowData.nodes = workflowData.nodes.filter(node => node.id !== nodeId);
+ }
+
+ // 处理全局鼠标事件
+ document.addEventListener('mousemove', function(e) {
+ if (isDragging && dragTarget) {
+ const x = e.clientX - dragOffset.x;
+ const y = e.clientY - dragOffset.y;
+
+ dragTarget.style.left = x + 'px';
+ dragTarget.style.top = y + 'px';
+
+ // 更新节点数据
+ const nodeId = dragTarget.id;
+ const node = workflowData.nodes.find(n => n.id === nodeId);
+ if (node) {
+ node.x = x;
+ node.y = y;
+ }
+
+ // 更新连接
+ updateNodeConnections(nodeId);
+ }
+
+ // 处理连接预览
+ if (isConnecting && connectionStart) {
+ updateConnectionPreview(e.clientX, e.clientY);
+ }
+ });
+
+ document.addEventListener('mouseup', function() {
+ if (isDragging && dragTarget) {
+ dragTarget.style.zIndex = '10';
+ isDragging = false;
+ dragTarget = null;
+ }
+
+ if (isConnecting) {
+ cancelConnection();
+ }
+ });
+
+ // 初始化侧边栏切换
+ const componentsTabBtn = document.getElementById('componentsTabBtn');
+ const templatesTabBtn = document.getElementById('templatesTabBtn');
+
+ if (componentsTabBtn) {
+ componentsTabBtn.addEventListener('click', function() {
+ document.getElementById('componentsPanel').style.display = 'block';
+ document.getElementById('templatesPanel').style.display = 'none';
+ this.classList.add('active');
+ templatesTabBtn.classList.remove('active');
+ });
+ }
+
+ if (templatesTabBtn) {
+ templatesTabBtn.addEventListener('click', function() {
+ document.getElementById('componentsPanel').style.display = 'none';
+ document.getElementById('templatesPanel').style.display = 'block';
+ this.classList.add('active');
+ componentsTabBtn.classList.remove('active');
+ });
+ }
+
+ // 初始加载示例模板
+ showSampleTemplates();
+
+ // 绘制连接
+ function drawConnection(sourceId, targetId, connectionId) {
+ const sourceNode = document.getElementById(sourceId);
+ const targetNode = document.getElementById(targetId);
+
+ if (!sourceNode || !targetNode) {
+ console.error('连接节点不存在:', sourceId, targetId);
+ return null;
+ }
+
+ const sourcePort = sourceNode.querySelector('.port-out');
+ const targetPort = targetNode.querySelector('.port-in');
+
+ const sourceRect = sourcePort.getBoundingClientRect();
+ const targetRect = targetPort.getBoundingClientRect();
+ const canvasRect = workflowCanvas.getBoundingClientRect();
+
+ const start = {
+ x: sourceRect.left + sourceRect.width/2 - canvasRect.left,
+ y: sourceRect.top + sourceRect.height/2 - canvasRect.top
+ };
+
+ const end = {
+ x: targetRect.left + targetRect.width/2 - canvasRect.left,
+ y: targetRect.top + targetRect.height/2 - canvasRect.top
+ };
+
+ // 创建连接路径
+ const path = document.createElementNS('http://www.w3.org/2000/svg', 'path');
+ path.setAttribute('class', 'connection-path');
+ path.setAttribute('id', 'connection_' + (connectionId || `${sourceId}_${targetId}`));
+
+ // 绘制贝塞尔曲线
+ const dx = Math.abs(end.x - start.x) * 0.5;
+ const pathData = `M ${start.x},${start.y} C ${start.x + dx},${start.y} ${end.x - dx},${end.y} ${end.x},${end.y}`;
+ path.setAttribute('d', pathData);
+
+ connectionsSvg.appendChild(path);
+ return path;
+ }
+
+ // 更新节点连接
+ function updateNodeConnections(nodeId) {
+ workflowData.connections.forEach(conn => {
+ if (conn.sourceId === nodeId || conn.targetId === nodeId) {
+ const path = document.getElementById('connection_' + conn.id);
+ if (path) {
+ path.parentNode.removeChild(path);
+ }
+ drawConnection(conn.sourceId, conn.targetId, conn.id);
+ }
+ });
+ }
+
+ // 开始创建连接
+ function startConnection(nodeId, e) {
+ isConnecting = true;
+ connectionStart = { id: nodeId, event: e };
+
+ // 创建预览连接线
+ connectionPreviewPath = document.createElementNS('http://www.w3.org/2000/svg', 'path');
+ connectionPreviewPath.setAttribute('class', 'connection-path');
+ connectionPreviewPath.style.strokeDasharray = '5,5';
+ connectionPreviewPath.style.opacity = '0.6';
+ connectionsSvg.appendChild(connectionPreviewPath);
+
+ updateConnectionPreview(e.clientX, e.clientY);
+ }
+
+ // 更新连接预览
+ function updateConnectionPreview(clientX, clientY) {
+ if (!connectionStart || !connectionPreviewPath) return;
+
+ const sourceNode = document.getElementById(connectionStart.id);
+ if (!sourceNode) return;
+
+ const sourcePort = sourceNode.querySelector('.port-out');
+ const sourceRect = sourcePort.getBoundingClientRect();
+ const canvasRect = workflowCanvas.getBoundingClientRect();
+
+ const start = {
+ x: sourceRect.left + sourceRect.width/2 - canvasRect.left,
+ y: sourceRect.top + sourceRect.height/2 - canvasRect.top
+ };
+
+ const end = {
+ x: clientX - canvasRect.left,
+ y: clientY - canvasRect.top
+ };
+
+ // 绘制预览连接
+ const dx = Math.abs(end.x - start.x) * 0.5;
+ const pathData = `M ${start.x},${start.y} C ${start.x + dx},${start.y} ${end.x - dx},${end.y} ${end.x},${end.y}`;
+ connectionPreviewPath.setAttribute('d', pathData);
+ }
+
+ // 完成连接
+ function completeConnection(sourceId, targetId) {
+ // 检查连接是否已存在
+ const connectionExists = workflowData.connections.some(conn =>
+ conn.sourceId === sourceId && conn.targetId === targetId);
+
+ if (connectionExists) {
+ cancelConnection();
+ return;
+ }
+
+ // 生成连接ID
+ const connectionId = `conn_${Date.now()}`;
+
+ // 添加到数据中
+ workflowData.connections.push({
+ id: connectionId,
+ sourceId: sourceId,
+ targetId: targetId
+ });
+
+ // 绘制最终连接
+ drawConnection(sourceId, targetId, connectionId);
+
+ // 清理预览状态
+ cancelConnection();
+ }
+
+ // 取消连接操作
+ function cancelConnection() {
+ if (connectionPreviewPath && connectionPreviewPath.parentNode) {
+ connectionPreviewPath.parentNode.removeChild(connectionPreviewPath);
+ }
+
+ isConnecting = false;
+ connectionStart = null;
+ connectionPreviewPath = null;
+ }
+
+ // 打开节点配置面板
+ function openNodeConfig(nodeData) {
+ const propertiesPanel = document.getElementById('propertiesPanel');
+ const propertiesContent = document.getElementById('propertiesContent');
+
+ // 显示面板
+ propertiesPanel.classList.add('open');
+
+ // 生成配置表单
+ const configOptions = getComponentConfigs(nodeData.type, nodeData.subtype);
+ let formHtml = `
+ ${nodeData.title} 配置
+
+ `;
+
+ propertiesContent.innerHTML = formHtml;
+
+ // 保存配置事件
+ document.getElementById('saveConfigBtn').addEventListener('click', function() {
+ const form = document.getElementById('nodeConfigForm');
+ const formData = new FormData(form);
+ const config = {};
+
+ // 构建配置对象
+ configOptions.forEach(option => {
+ if (option.type === 'checkbox') {
+ config[option.id] = document.getElementById(option.id).checked;
+ } else {
+ config[option.id] = formData.get(option.id);
+ }
+ });
+
+ // 更新节点配置
+ const node = workflowData.nodes.find(n => n.id === nodeData.id);
+ if (node) {
+ node.config = config;
+
+ // 更新节点显示
+ const nodeElement = document.getElementById(nodeData.id);
+ if (nodeElement) {
+ const descElement = nodeElement.querySelector('.node-description');
+ if (descElement) {
+ descElement.textContent = '已配置';
+ }
+ }
+ }
+
+ // 关闭面板
+ closePropertiesPanel();
+ });
+ }
+
+ // 关闭属性面板
+ function closePropertiesPanel() {
+ document.getElementById('propertiesPanel').classList.remove('open');
+ }
+
+ // 绑定关闭属性面板的事件
+ document.getElementById('closePropertiesBtn').addEventListener('click', closePropertiesPanel);
+});
diff --git a/templates/workflow_editor.html b/templates/workflow_editor.html
new file mode 100644
index 0000000..13a60f2
--- /dev/null
+++ b/templates/workflow_editor.html
@@ -0,0 +1,519 @@
+
+
+
+
+
+ 工作流编辑器 - 微博舆情分析系统
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
确认要运行当前工作流吗?
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
状态信息
+
+
任务ID: -
+
状态: -
+
开始时间: -
+
完成时间: -
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/utils/cache_manager.py b/utils/cache_manager.py
index 4955a5d..e8158f0 100644
--- a/utils/cache_manager.py
+++ b/utils/cache_manager.py
@@ -1,116 +1,301 @@
import json
import os
import time
+import shutil
from datetime import datetime, timedelta
import threading
import queue
+from collections import OrderedDict
+import pickle
+import hashlib
+import logging
-class PredictionCache:
+logger = logging.getLogger('cache_manager')
+logger.setLevel(logging.INFO)
+
+class LRUCache:
+ """实现LRU (Least Recently Used) 缓存策略"""
+
+ def __init__(self, capacity):
+ self.cache = OrderedDict()
+ self.capacity = capacity
+
+ def get(self, key):
+ if key not in self.cache:
+ return None
+ # 访问元素时,将其移至末尾,表示最近使用
+ self.cache.move_to_end(key)
+ return self.cache[key]
+
+ def put(self, key, value):
+ # 如果键已存在,更新值并将其移至末尾
+ if key in self.cache:
+ self.cache[key] = value
+ self.cache.move_to_end(key)
+ return
+
+ # 如果缓存已满,删除最久未使用的项(OrderedDict 的首项)
+ if len(self.cache) >= self.capacity:
+ self.cache.popitem(last=False)
+
+ # 添加新项至末尾
+ self.cache[key] = value
+
+ def remove(self, key):
+ if key in self.cache:
+ del self.cache[key]
+
+ def clear(self):
+ self.cache.clear()
+
+ def __len__(self):
+ return len(self.cache)
+
+ def get_all_keys(self):
+ return list(self.cache.keys())
+
+
+class CacheManager:
+ """两级缓存系统:内存LRU缓存 + 磁盘持久化缓存"""
+
_instance = None
_lock = threading.Lock()
- def __new__(cls):
+ def __new__(cls, *args, **kwargs):
with cls._lock:
if cls._instance is None:
- cls._instance = super(PredictionCache, cls).__new__(cls)
+ cls._instance = super(CacheManager, cls).__new__(cls)
return cls._instance
- def __init__(self):
- if not hasattr(self, 'initialized'):
- self.cache_dir = 'cache/predictions'
- self.cache_duration = timedelta(hours=24) # 缓存24小时
- self.cache = {}
- self.cache_queue = queue.Queue()
+ def __init__(self, name="default", memory_capacity=1000, cache_duration=24,
+ disk_cache_dir="cache", flush_interval=5):
+ if hasattr(self, 'initialized'):
+ return
+
+ self.name = name
+ self.memory_cache = LRUCache(memory_capacity)
+ self.disk_cache_dir = os.path.join(disk_cache_dir, name)
+ self.cache_duration = timedelta(hours=cache_duration)
+ self.flush_interval = flush_interval # 定时将内存缓存刷新到磁盘的间隔(分钟)
+ self.cache_stats = {"hits": 0, "misses": 0, "disk_hits": 0}
+ self.disk_queue = queue.Queue()
self.initialized = True
# 确保缓存目录存在
- os.makedirs(self.cache_dir, exist_ok=True)
+ os.makedirs(self.disk_cache_dir, exist_ok=True)
- # 启动缓存清理线程
- self.cleanup_thread = threading.Thread(target=self._cleanup_old_cache, daemon=True)
+ # 启动缓存管理线程
+ self.cleanup_thread = threading.Thread(target=self._cleanup_and_flush_task, daemon=True)
self.cleanup_thread.start()
- # 加载现有缓存
- self._load_cache()
+ # 启动磁盘写入线程
+ self.disk_writer_thread = threading.Thread(target=self._disk_writer_task, daemon=True)
+ self.disk_writer_thread.start()
+
+ logger.info(f"初始化缓存管理器: {name},内存容量: {memory_capacity}项,缓存时间: {cache_duration}小时")
- def _load_cache(self):
- """加载磁盘上的缓存文件"""
- try:
- for filename in os.listdir(self.cache_dir):
- if filename.endswith('.json'):
- filepath = os.path.join(self.cache_dir, filename)
- with open(filepath, 'r', encoding='utf-8') as f:
- cache_data = json.load(f)
- # 检查缓存是否过期
- if self._is_cache_valid(cache_data['timestamp']):
- topic = filename[:-5] # 移除.json后缀
- self.cache[topic] = cache_data
- else:
- # 删除过期缓存文件
- os.remove(filepath)
- except Exception as e:
- print(f"加载缓存失败: {e}")
+ def _get_cache_key(self, key):
+ """标准化缓存键"""
+ if isinstance(key, str):
+ return key
+ return hashlib.md5(str(key).encode()).hexdigest()
- def _cleanup_old_cache(self):
- """定期清理过期缓存的后台线程"""
- while True:
- try:
- # 检查并清理内存缓存
- current_time = datetime.now()
- expired_topics = []
-
- for topic, cache_data in self.cache.items():
- if not self._is_cache_valid(cache_data['timestamp']):
- expired_topics.append(topic)
-
- # 删除过期缓存
- for topic in expired_topics:
- del self.cache[topic]
- cache_file = os.path.join(self.cache_dir, f"{topic}.json")
- if os.path.exists(cache_file):
- os.remove(cache_file)
-
- # 休眠1小时后再次检查
- time.sleep(3600)
- except Exception as e:
- print(f"清理缓存时出错: {e}")
- time.sleep(3600) # 发生错误时也等待1小时
+ def _get_disk_path(self, key):
+ """获取磁盘缓存路径"""
+ safe_key = self._get_cache_key(key)
+ return os.path.join(self.disk_cache_dir, f"{safe_key}.cache")
def _is_cache_valid(self, timestamp):
- """检查缓存是否有效"""
+ """检查缓存是否过期"""
cache_time = datetime.fromtimestamp(timestamp)
return datetime.now() - cache_time < self.cache_duration
- def get(self, topic):
- """获取话题的预测缓存"""
- if topic in self.cache and self._is_cache_valid(self.cache[topic]['timestamp']):
- return self.cache[topic]['prediction']
+ def get(self, key):
+ """获取缓存数据,首先检查内存,然后检查磁盘"""
+ cache_key = self._get_cache_key(key)
+
+ # 1. 检查内存缓存
+ cache_data = self.memory_cache.get(cache_key)
+ if cache_data is not None:
+ if self._is_cache_valid(cache_data['timestamp']):
+ self.cache_stats["hits"] += 1
+ logger.debug(f"内存缓存命中: {key}")
+ return cache_data['data']
+ else:
+ # 过期缓存,从内存中删除
+ self.memory_cache.remove(cache_key)
+
+ # 2. 检查磁盘缓存
+ disk_path = self._get_disk_path(cache_key)
+ if os.path.exists(disk_path):
+ try:
+ with open(disk_path, 'rb') as f:
+ cache_data = pickle.load(f)
+
+ if self._is_cache_valid(cache_data['timestamp']):
+ # 从磁盘加载后,放入内存缓存
+ self.memory_cache.put(cache_key, cache_data)
+ self.cache_stats["disk_hits"] += 1
+ logger.debug(f"磁盘缓存命中: {key}")
+ return cache_data['data']
+ else:
+ # 过期缓存,删除磁盘文件
+ os.remove(disk_path)
+ except Exception as e:
+ logger.warning(f"读取磁盘缓存失败: {key}, 错误: {e}")
+
+ self.cache_stats["misses"] += 1
+ logger.debug(f"缓存未命中: {key}")
return None
- def set(self, topic, prediction):
- """设置话题的预测缓存"""
+ def set(self, key, data, immediate_disk_write=False):
+ """设置缓存数据,同时更新内存和安排磁盘写入"""
+ cache_key = self._get_cache_key(key)
cache_data = {
- 'prediction': prediction,
+ 'data': data,
'timestamp': datetime.now().timestamp()
}
# 更新内存缓存
- self.cache[topic] = cache_data
+ self.memory_cache.put(cache_key, cache_data)
- # 异步保存到磁盘
- self.cache_queue.put((topic, cache_data))
- threading.Thread(target=self._save_cache_to_disk, daemon=True).start()
+ # 安排写入磁盘
+ if immediate_disk_write:
+ self._write_to_disk(cache_key, cache_data)
+ else:
+ self.disk_queue.put((cache_key, cache_data))
+
+ logger.debug(f"缓存已设置: {key}")
+ return True
- def _save_cache_to_disk(self):
- """异步保存缓存到磁盘"""
+ def invalidate(self, key):
+ """使指定键的缓存失效"""
+ cache_key = self._get_cache_key(key)
+
+ # 从内存中删除
+ self.memory_cache.remove(cache_key)
+
+ # 从磁盘中删除
+ disk_path = self._get_disk_path(cache_key)
+ if os.path.exists(disk_path):
+ try:
+ os.remove(disk_path)
+ logger.debug(f"缓存已失效: {key}")
+ except Exception as e:
+ logger.warning(f"删除磁盘缓存失败: {key}, 错误: {e}")
+
+ return True
+
+ def clear_all(self):
+ """清除所有缓存"""
+ # 清除内存缓存
+ self.memory_cache.clear()
+
+ # 清除磁盘缓存
try:
- while not self.cache_queue.empty():
- topic, cache_data = self.cache_queue.get()
- cache_file = os.path.join(self.cache_dir, f"{topic}.json")
- with open(cache_file, 'w', encoding='utf-8') as f:
- json.dump(cache_data, f, ensure_ascii=False, indent=2)
+ shutil.rmtree(self.disk_cache_dir)
+ os.makedirs(self.disk_cache_dir, exist_ok=True)
+ logger.info(f"所有缓存已清除: {self.name}")
except Exception as e:
- print(f"保存缓存到磁盘失败: {e}")
+ logger.error(f"清除磁盘缓存失败: {e}")
+
+ # 重置统计信息
+ self.cache_stats = {"hits": 0, "misses": 0, "disk_hits": 0}
+
+ return True
+
+ def get_stats(self):
+ """获取缓存统计信息"""
+ total_requests = self.cache_stats["hits"] + self.cache_stats["misses"]
+ hit_rate = (self.cache_stats["hits"] / total_requests * 100) if total_requests > 0 else 0
+ total_hits = self.cache_stats["hits"] + self.cache_stats["disk_hits"]
+
+ memory_size = len(self.memory_cache)
+ disk_size = len([f for f in os.listdir(self.disk_cache_dir) if f.endswith('.cache')])
+
+ return {
+ "name": self.name,
+ "memory_items": memory_size,
+ "disk_items": disk_size,
+ "memory_hits": self.cache_stats["hits"],
+ "disk_hits": self.cache_stats["disk_hits"],
+ "misses": self.cache_stats["misses"],
+ "total_requests": total_requests,
+ "hit_rate": hit_rate,
+ "two_level_hit_rate": (total_hits / total_requests * 100) if total_requests > 0 else 0
+ }
+
+ def _write_to_disk(self, cache_key, cache_data):
+ """将缓存写入磁盘"""
+ disk_path = self._get_disk_path(cache_key)
+ try:
+ with open(disk_path, 'wb') as f:
+ pickle.dump(cache_data, f)
+ return True
+ except Exception as e:
+ logger.warning(f"写入磁盘缓存失败: {cache_key}, 错误: {e}")
+ return False
+
+ def _disk_writer_task(self):
+ """后台线程,负责将缓存写入磁盘"""
+ while True:
+ try:
+ # 尝试从队列获取条目,超时后继续循环
+ try:
+ cache_key, cache_data = self.disk_queue.get(timeout=1)
+ self._write_to_disk(cache_key, cache_data)
+ self.disk_queue.task_done()
+ except queue.Empty:
+ time.sleep(0.1)
+ except Exception as e:
+ logger.error(f"磁盘写入线程出错: {e}")
+ time.sleep(5) # 发生错误时等待一段时间
+
+ def _cleanup_and_flush_task(self):
+ """后台线程,负责清理过期缓存和定期刷新内存缓存到磁盘"""
+ while True:
+ try:
+ # 1. 清理过期的内存缓存
+ current_time = datetime.now()
+ for key in self.memory_cache.get_all_keys():
+ cache_data = self.memory_cache.get(key)
+ if not self._is_cache_valid(cache_data['timestamp']):
+ self.memory_cache.remove(key)
+
+ # 2. 清理过期的磁盘缓存
+ for filename in os.listdir(self.disk_cache_dir):
+ if filename.endswith('.cache'):
+ filepath = os.path.join(self.disk_cache_dir, filename)
+ try:
+ with open(filepath, 'rb') as f:
+ cache_data = pickle.load(f)
+ if not self._is_cache_valid(cache_data['timestamp']):
+ os.remove(filepath)
+ except Exception as e:
+ # 清理损坏的缓存文件
+ logger.warning(f"读取缓存文件失败,将删除: {filepath}, 错误: {e}")
+ os.remove(filepath)
+
+ # 3. 将内存缓存刷新到磁盘
+ # 注意:这会重写已经写入磁盘的缓存,但确保内存和磁盘保持同步
+ for key in self.memory_cache.get_all_keys():
+ cache_data = self.memory_cache.get(key)
+ self._write_to_disk(key, cache_data)
+
+ # 每小时执行一次清理
+ time.sleep(3600)
+ except Exception as e:
+ logger.error(f"缓存清理线程出错: {e}")
+ time.sleep(3600) # 发生错误时也等待一段时间
-# 创建全局缓存实例
-prediction_cache = PredictionCache()
\ No newline at end of file
+
+# 创建不同领域的缓存实例
+prediction_cache = CacheManager(name="predictions", memory_capacity=500, cache_duration=24)
+sentiment_cache = CacheManager(name="sentiment", memory_capacity=1000, cache_duration=12)
+topic_cache = CacheManager(name="topics", memory_capacity=200, cache_duration=6)
+user_data_cache = CacheManager(name="user_data", memory_capacity=300, cache_duration=48)
+
+# 向后兼容的别名
+PredictionCache = CacheManager
+# 为保持向后兼容,我们保留原来的prediction_cache
+prediction_cache_old = prediction_cache
\ No newline at end of file
diff --git a/utils/init_wizard.py b/utils/init_wizard.py
new file mode 100644
index 0000000..8e032fb
--- /dev/null
+++ b/utils/init_wizard.py
@@ -0,0 +1,558 @@
+import os
+import sys
+import json
+import getpass
+import secrets
+import logging
+import platform
+import socket
+import hashlib
+import base64
+import re
+import shutil
+import subprocess
+from pathlib import Path
+from datetime import datetime
+import pymysql
+from dotenv import load_dotenv, set_key, find_dotenv
+
+# 设置日志
+logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+ handlers=[logging.StreamHandler(sys.stdout)]
+)
+logger = logging.getLogger('init_wizard')
+
+class InitWizard:
+ """
+ 初始化向导 - 简化系统的初始配置流程,并提供安全加固功能
+ """
+
+ def __init__(self):
+ # 加载环境变量
+ load_dotenv()
+
+ # 配置项
+ self.config = {
+ # 数据库配置
+ 'db': {
+ 'host': os.getenv('DB_HOST', 'localhost'),
+ 'port': int(os.getenv('DB_PORT', '3306')),
+ 'user': os.getenv('DB_USER', 'root'),
+ 'password': os.getenv('DB_PASSWORD', ''),
+ 'database': os.getenv('DB_NAME', 'Weibo_PublicOpinion_AnalysisSystem'),
+ 'ssl': bool(os.getenv('DB_SSL', 'false').lower() == 'true')
+ },
+ # Flask应用配置
+ 'app': {
+ 'host': os.getenv('FLASK_HOST', '127.0.0.1'),
+ 'port': int(os.getenv('FLASK_PORT', '5000')),
+ 'secret_key': os.getenv('FLASK_SECRET_KEY', ''),
+ 'enable_https': bool(os.getenv('ENABLE_HTTPS', 'false').lower() == 'true'),
+ 'debug': bool(os.getenv('FLASK_DEBUG', 'false').lower() == 'true')
+ },
+ # API密钥配置
+ 'api_keys': {
+ 'openai': os.getenv('OPENAI_API_KEY', ''),
+ 'anthropic': os.getenv('ANTHROPIC_API_KEY', ''),
+ 'deepseek': os.getenv('DEEPSEEK_API_KEY', '')
+ },
+ # 安全配置
+ 'security': {
+ 'enable_rate_limit': bool(os.getenv('ENABLE_RATE_LIMIT', 'true').lower() == 'true'),
+ 'enable_ip_blocking': bool(os.getenv('ENABLE_IP_BLOCKING', 'true').lower() == 'true'),
+ 'enable_sensitive_data_filter': bool(os.getenv('ENABLE_SENSITIVE_DATA_FILTER', 'true').lower() == 'true'),
+ 'enable_mutual_auth': bool(os.getenv('ENABLE_MUTUAL_AUTH', 'false').lower() == 'true'),
+ 'min_password_length': int(os.getenv('MIN_PASSWORD_LENGTH', '8')),
+ 'session_timeout': int(os.getenv('SESSION_TIMEOUT', '120')), # 分钟
+ },
+ # 爬虫配置
+ 'crawler': {
+ 'interval': int(os.getenv('CRAWL_INTERVAL', '18000')), # 秒
+ 'max_retries': int(os.getenv('CRAWL_MAX_RETRIES', '3')),
+ 'timeout': int(os.getenv('CRAWL_TIMEOUT', '30')),
+ 'max_concurrent': int(os.getenv('CRAWL_MAX_CONCURRENT', '2')),
+ 'user_agent': os.getenv('CRAWL_USER_AGENT', 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36')
+ },
+ # 系统配置
+ 'system': {
+ 'initialized': bool(os.getenv('SYSTEM_INITIALIZED', 'false').lower() == 'true'),
+ 'version': os.getenv('SYSTEM_VERSION', '2.0.0'),
+ 'log_level': os.getenv('LOG_LEVEL', 'INFO'),
+ 'data_dir': os.getenv('DATA_DIR', 'data'),
+ 'temp_dir': os.getenv('TEMP_DIR', 'temp'),
+ 'cache_dir': os.getenv('CACHE_DIR', 'cache'),
+ 'max_model_memory': float(os.getenv('MAX_MODEL_MEMORY_USAGE', '4.0')), # GB
+ }
+ }
+
+ # 安全选项
+ self.security_options = {
+ 'rate_limit': {
+ 'name': '请求速率限制',
+ 'description': '防止API被滥用,限制单个IP的请求频率',
+ 'default': True
+ },
+ 'ip_blocking': {
+ 'name': 'IP黑名单',
+ 'description': '阻止可疑IP访问系统',
+ 'default': True
+ },
+ 'sensitive_data_filter': {
+ 'name': '敏感信息过滤',
+ 'description': '自动识别并屏蔽输出内容中的敏感信息(如手机号、邮箱等)',
+ 'default': True
+ },
+ 'mutual_auth': {
+ 'name': '双向认证',
+ 'description': '要求API调用方提供有效证书,增强API安全性(需要HTTPS)',
+ 'default': False
+ }
+ }
+
+ def start(self):
+ """启动初始化向导"""
+ self._print_welcome()
+
+ if self.config['system']['initialized']:
+ print("\n系统已经初始化过。您想重新配置吗? [y/N]: ", end='')
+ choice = input().strip().lower()
+ if choice != 'y':
+ print("初始化向导已退出。如需重新配置,请设置环境变量 SYSTEM_INITIALIZED=false 或删除 .env 文件。")
+ return
+
+ # 主配置流程
+ try:
+ self._configure_database()
+ self._configure_app()
+ self._configure_api_keys()
+ self._configure_security()
+ self._configure_crawler()
+ self._configure_system()
+
+ # 保存配置
+ self._save_config()
+
+ # 应用安全措施
+ self._apply_security_measures()
+
+ print("\n✅ 初始化完成!系统已成功配置。")
+ print("您现在可以运行 python app.py 启动应用。")
+
+ except KeyboardInterrupt:
+ print("\n\n初始化向导已取消。配置未保存。")
+ except Exception as e:
+ logger.error(f"初始化过程中发生错误: {e}")
+ print(f"\n❌ 初始化失败: {e}")
+ print("请检查错误并重试。")
+
+ def _print_welcome(self):
+ """打印欢迎信息"""
+ print("\n" + "="*80)
+ print(" "*20 + "微博舆情分析预测系统 - 初始化向导 v2.0")
+ print("="*80)
+ print("\n欢迎使用微博舆情分析预测系统!此向导将引导您完成系统的初始配置。")
+ print("按Ctrl+C可随时退出向导。")
+ print("\n系统信息:")
+ print(f" • 操作系统: {platform.system()} {platform.release()}")
+ print(f" • Python版本: {platform.python_version()}")
+ print(f" • 主机名: {socket.gethostname()}")
+ print(f" • 当前时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+ print("\n让我们开始配置吧!每个选项都有默认值,直接按回车即可使用默认值。")
+ print("-"*80)
+
+ def _configure_database(self):
+ """配置数据库连接"""
+ print("\n📦 数据库配置")
+ print("-"*50)
+
+ # 询问数据库连接信息
+ self.config['db']['host'] = self._prompt(
+ "数据库主机", self.config['db']['host'])
+
+ port_str = self._prompt(
+ "数据库端口", str(self.config['db']['port']))
+ try:
+ self.config['db']['port'] = int(port_str)
+ except ValueError:
+ print(f"端口号无效,使用默认值 {self.config['db']['port']}")
+
+ self.config['db']['user'] = self._prompt(
+ "数据库用户名", self.config['db']['user'])
+
+ # 密码使用getpass以避免明文显示
+ default_pass = '*' * len(self.config['db']['password']) if self.config['db']['password'] else ''
+ password = getpass.getpass(f"数据库密码 [{default_pass}]: ")
+ if password:
+ self.config['db']['password'] = password
+
+ self.config['db']['database'] = self._prompt(
+ "数据库名", self.config['db']['database'])
+
+ ssl_str = self._prompt(
+ "使用SSL连接 (true/false)", str(self.config['db']['ssl']).lower())
+ self.config['db']['ssl'] = ssl_str.lower() == 'true'
+
+ # 测试数据库连接
+ print("\n正在测试数据库连接...")
+ try:
+ self._test_db_connection()
+ print("✅ 数据库连接成功!")
+ except Exception as e:
+ print(f"❌ 数据库连接失败: {e}")
+ retry = input("是否重新配置数据库连接? [Y/n]: ").strip().lower()
+ if retry != 'n':
+ return self._configure_database()
+ else:
+ print("跳过数据库连接测试,但配置可能不正确。")
+
+ def _configure_app(self):
+ """配置Flask应用"""
+ print("\n🚀 应用配置")
+ print("-"*50)
+
+ self.config['app']['host'] = self._prompt(
+ "监听地址 (0.0.0.0表示所有网络接口)", self.config['app']['host'])
+
+ port_str = self._prompt(
+ "监听端口", str(self.config['app']['port']))
+ try:
+ self.config['app']['port'] = int(port_str)
+ except ValueError:
+ print(f"端口号无效,使用默认值 {self.config['app']['port']}")
+
+ # 自动生成密钥
+ if not self.config['app']['secret_key']:
+ self.config['app']['secret_key'] = secrets.token_hex(32)
+ print(f"已自动生成应用密钥: {self.config['app']['secret_key'][:8]}...")
+ else:
+ regenerate = input("应用密钥已存在。是否重新生成? [y/N]: ").strip().lower()
+ if regenerate == 'y':
+ self.config['app']['secret_key'] = secrets.token_hex(32)
+ print(f"已重新生成应用密钥: {self.config['app']['secret_key'][:8]}...")
+
+ https_str = self._prompt(
+ "启用HTTPS (true/false)", str(self.config['app']['enable_https']).lower())
+ self.config['app']['enable_https'] = https_str.lower() == 'true'
+
+ debug_str = self._prompt(
+ "启用调试模式 (true/false, 生产环境建议false)", str(self.config['app']['debug']).lower())
+ self.config['app']['debug'] = debug_str.lower() == 'true'
+
+ def _configure_api_keys(self):
+ """配置API密钥"""
+ print("\n🔑 API密钥配置")
+ print("-"*50)
+ print("系统支持多个大语言模型,至少需要配置一个API密钥。")
+
+ # 配置OpenAI API密钥
+ has_openai = self._prompt(
+ "是否配置OpenAI API密钥? (y/n)", "y" if self.config['api_keys']['openai'] else "n")
+ if has_openai.lower() == 'y':
+ self.config['api_keys']['openai'] = self._prompt(
+ "OpenAI API密钥", self.config['api_keys']['openai'])
+
+ # 配置Anthropic API密钥
+ has_anthropic = self._prompt(
+ "是否配置Anthropic (Claude) API密钥? (y/n)", "y" if self.config['api_keys']['anthropic'] else "n")
+ if has_anthropic.lower() == 'y':
+ self.config['api_keys']['anthropic'] = self._prompt(
+ "Anthropic API密钥", self.config['api_keys']['anthropic'])
+
+ # 配置DeepSeek API密钥
+ has_deepseek = self._prompt(
+ "是否配置DeepSeek API密钥? (y/n)", "y" if self.config['api_keys']['deepseek'] else "n")
+ if has_deepseek.lower() == 'y':
+ self.config['api_keys']['deepseek'] = self._prompt(
+ "DeepSeek API密钥", self.config['api_keys']['deepseek'])
+
+ # 检查是否至少配置了一个API密钥
+ if not (self.config['api_keys']['openai'] or self.config['api_keys']['anthropic'] or self.config['api_keys']['deepseek']):
+ print("⚠️ 警告: 您未配置任何API密钥,系统的AI分析功能将不可用。")
+ confirm = input("是否继续? [Y/n]: ").strip().lower()
+ if confirm == 'n':
+ return self._configure_api_keys()
+
+ def _configure_security(self):
+ """配置安全设置"""
+ print("\n🔒 安全配置")
+ print("-"*50)
+
+ for key, option in self.security_options.items():
+ current_value = self.config['security'][f'enable_{key}']
+ print(f"\n{option['name']}: {option['description']}")
+ enable_str = self._prompt(
+ f"启用{option['name']} (true/false)", str(current_value).lower())
+ self.config['security'][f'enable_{key}'] = enable_str.lower() == 'true'
+
+ # 密码安全策略
+ min_len_str = self._prompt(
+ "最小密码长度 (推荐不低于8)", str(self.config['security']['min_password_length']))
+ try:
+ self.config['security']['min_password_length'] = int(min_len_str)
+ if self.config['security']['min_password_length'] < 6:
+ print("⚠️ 警告: 短密码容易被暴力破解,建议设置更长的密码。")
+ except ValueError:
+ print(f"无效输入,使用默认值 {self.config['security']['min_password_length']}")
+
+ # 会话超时设置
+ timeout_str = self._prompt(
+ "会话超时时间 (分钟)", str(self.config['security']['session_timeout']))
+ try:
+ self.config['security']['session_timeout'] = int(timeout_str)
+ except ValueError:
+ print(f"无效输入,使用默认值 {self.config['security']['session_timeout']}")
+
+ def _configure_crawler(self):
+ """配置爬虫设置"""
+ print("\n🕷️ 爬虫配置")
+ print("-"*50)
+
+ interval_str = self._prompt(
+ "爬取间隔 (秒)", str(self.config['crawler']['interval']))
+ try:
+ self.config['crawler']['interval'] = int(interval_str)
+ except ValueError:
+ print(f"无效输入,使用默认值 {self.config['crawler']['interval']}")
+
+ retries_str = self._prompt(
+ "最大重试次数", str(self.config['crawler']['max_retries']))
+ try:
+ self.config['crawler']['max_retries'] = int(retries_str)
+ except ValueError:
+ print(f"无效输入,使用默认值 {self.config['crawler']['max_retries']}")
+
+ timeout_str = self._prompt(
+ "超时时间 (秒)", str(self.config['crawler']['timeout']))
+ try:
+ self.config['crawler']['timeout'] = int(timeout_str)
+ except ValueError:
+ print(f"无效输入,使用默认值 {self.config['crawler']['timeout']}")
+
+ concurrent_str = self._prompt(
+ "最大并发数", str(self.config['crawler']['max_concurrent']))
+ try:
+ self.config['crawler']['max_concurrent'] = int(concurrent_str)
+ except ValueError:
+ print(f"无效输入,使用默认值 {self.config['crawler']['max_concurrent']}")
+
+ self.config['crawler']['user_agent'] = self._prompt(
+ "User-Agent", self.config['crawler']['user_agent'])
+
+ def _configure_system(self):
+ """配置系统设置"""
+ print("\n⚙️ 系统配置")
+ print("-"*50)
+
+ # 日志级别
+ log_levels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
+ current_level = self.config['system']['log_level']
+ print(f"可选日志级别: {', '.join(log_levels)}")
+ log_level = self._prompt("日志级别", current_level).upper()
+ if log_level in log_levels:
+ self.config['system']['log_level'] = log_level
+ else:
+ print(f"无效的日志级别,使用默认值 {current_level}")
+
+ # 数据目录
+ data_dir = self._prompt("数据目录", self.config['system']['data_dir'])
+ if data_dir:
+ self.config['system']['data_dir'] = data_dir
+ os.makedirs(data_dir, exist_ok=True)
+ print(f"已创建数据目录: {data_dir}")
+
+ # 缓存目录
+ cache_dir = self._prompt("缓存目录", self.config['system']['cache_dir'])
+ if cache_dir:
+ self.config['system']['cache_dir'] = cache_dir
+ os.makedirs(cache_dir, exist_ok=True)
+ print(f"已创建缓存目录: {cache_dir}")
+
+ # 临时目录
+ temp_dir = self._prompt("临时文件目录", self.config['system']['temp_dir'])
+ if temp_dir:
+ self.config['system']['temp_dir'] = temp_dir
+ os.makedirs(temp_dir, exist_ok=True)
+ print(f"已创建临时文件目录: {temp_dir}")
+
+ # 模型内存限制
+ memory_str = self._prompt(
+ "最大模型内存使用量 (GB)", str(self.config['system']['max_model_memory']))
+ try:
+ self.config['system']['max_model_memory'] = float(memory_str)
+ except ValueError:
+ print(f"无效输入,使用默认值 {self.config['system']['max_model_memory']}")
+
+ # 标记系统已初始化
+ self.config['system']['initialized'] = True
+
+ def _save_config(self):
+ """保存配置到.env文件"""
+ print("\n正在保存配置...")
+
+ # 构建.env文件内容
+ env_content = [
+ "# 微博舆情分析预测系统配置文件",
+ f"# 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
+ "",
+ "# 数据库配置",
+ f"DB_HOST={self.config['db']['host']}",
+ f"DB_PORT={self.config['db']['port']}",
+ f"DB_USER={self.config['db']['user']}",
+ f"DB_PASSWORD={self.config['db']['password']}",
+ f"DB_NAME={self.config['db']['database']}",
+ f"DB_SSL={str(self.config['db']['ssl']).lower()}",
+ "",
+ "# 应用配置",
+ f"FLASK_HOST={self.config['app']['host']}",
+ f"FLASK_PORT={self.config['app']['port']}",
+ f"FLASK_SECRET_KEY={self.config['app']['secret_key']}",
+ f"ENABLE_HTTPS={str(self.config['app']['enable_https']).lower()}",
+ f"FLASK_DEBUG={str(self.config['app']['debug']).lower()}",
+ "",
+ "# API密钥",
+ f"OPENAI_API_KEY={self.config['api_keys']['openai']}",
+ f"ANTHROPIC_API_KEY={self.config['api_keys']['anthropic']}",
+ f"DEEPSEEK_API_KEY={self.config['api_keys']['deepseek']}",
+ "",
+ "# 安全配置",
+ f"ENABLE_RATE_LIMIT={str(self.config['security']['enable_rate_limit']).lower()}",
+ f"ENABLE_IP_BLOCKING={str(self.config['security']['enable_ip_blocking']).lower()}",
+ f"ENABLE_SENSITIVE_DATA_FILTER={str(self.config['security']['enable_sensitive_data_filter']).lower()}",
+ f"ENABLE_MUTUAL_AUTH={str(self.config['security']['enable_mutual_auth']).lower()}",
+ f"MIN_PASSWORD_LENGTH={self.config['security']['min_password_length']}",
+ f"SESSION_TIMEOUT={self.config['security']['session_timeout']}",
+ "",
+ "# 爬虫配置",
+ f"CRAWL_INTERVAL={self.config['crawler']['interval']}",
+ f"CRAWL_MAX_RETRIES={self.config['crawler']['max_retries']}",
+ f"CRAWL_TIMEOUT={self.config['crawler']['timeout']}",
+ f"CRAWL_MAX_CONCURRENT={self.config['crawler']['max_concurrent']}",
+ f"CRAWL_USER_AGENT={self.config['crawler']['user_agent']}",
+ "",
+ "# 系统配置",
+ f"SYSTEM_INITIALIZED={str(self.config['system']['initialized']).lower()}",
+ f"SYSTEM_VERSION={self.config['system']['version']}",
+ f"LOG_LEVEL={self.config['system']['log_level']}",
+ f"DATA_DIR={self.config['system']['data_dir']}",
+ f"TEMP_DIR={self.config['system']['temp_dir']}",
+ f"CACHE_DIR={self.config['system']['cache_dir']}",
+ f"MAX_MODEL_MEMORY_USAGE={self.config['system']['max_model_memory']}",
+ ]
+
+ # 写入.env文件
+ with open('.env', 'w') as f:
+ f.write('\n'.join(env_content))
+
+ print("✅ 配置已保存到 .env 文件")
+
+ # 创建备份
+ backup_path = f".env.backup.{datetime.now().strftime('%Y%m%d%H%M%S')}"
+ shutil.copy2('.env', backup_path)
+ print(f"✅ 配置备份已保存到 {backup_path}")
+
+ def _test_db_connection(self):
+ """测试数据库连接"""
+ connection = pymysql.connect(
+ host=self.config['db']['host'],
+ port=self.config['db']['port'],
+ user=self.config['db']['user'],
+ password=self.config['db']['password'],
+ database=self.config['db']['database'],
+ charset='utf8mb4',
+ ssl={'ssl': {'ca': None}} if self.config['db']['ssl'] else None
+ )
+ connection.close()
+
+ def _apply_security_measures(self):
+ """应用安全措施"""
+ print("\n正在应用安全措施...")
+
+ # 创建相关目录
+ security_dir = os.path.join(self.config['system']['data_dir'], 'security')
+ os.makedirs(security_dir, exist_ok=True)
+
+ # 设置文件权限
+ try:
+ # 仅在类Unix系统上设置文件权限
+ if platform.system() != "Windows":
+ os.chmod('.env', 0o600) # 只有所有者可读写
+ print("✅ 已设置.env文件权限为600 (只有所有者可读写)")
+ except Exception as e:
+ logger.warning(f"设置文件权限失败: {e}")
+
+ # 生成密钥对(如果启用了双向认证)
+ if self.config['security']['enable_mutual_auth']:
+ cert_dir = os.path.join(security_dir, 'certs')
+ os.makedirs(cert_dir, exist_ok=True)
+
+ try:
+ # 检查是否有OpenSSL可用
+ subprocess.run(['openssl', 'version'], check=True, capture_output=True)
+
+ # 生成自签名证书
+ key_file = os.path.join(cert_dir, 'server.key')
+ cert_file = os.path.join(cert_dir, 'server.crt')
+
+ if not os.path.exists(key_file) or not os.path.exists(cert_file):
+ print("正在生成SSL证书...")
+ subprocess.run([
+ 'openssl', 'req', '-x509', '-newkey', 'rsa:4096',
+ '-keyout', key_file, '-out', cert_file,
+ '-days', '365', '-nodes',
+ '-subj', '/CN=localhost'
+ ], check=True)
+ print(f"✅ SSL证书已生成: {cert_file}")
+ except subprocess.CalledProcessError:
+ print("⚠️ OpenSSL不可用,无法生成SSL证书。如需使用HTTPS,请手动配置证书。")
+ except Exception as e:
+ logger.warning(f"生成SSL证书失败: {e}")
+
+ # 创建敏感信息过滤器配置
+ if self.config['security']['enable_sensitive_data_filter']:
+ filter_config = {
+ 'enabled': True,
+ 'patterns': {
+ 'phone': r'\b1[3-9]\d{9}\b',
+ 'email': r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
+ 'id_card': r'\b[1-9]\d{5}(19|20)\d{2}(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])\d{3}[\dXx]\b',
+ 'credit_card': r'\b\d{4}[ -]?\d{4}[ -]?\d{4}[ -]?\d{4}\b',
+ 'address': r'(北京|上海|广州|深圳|天津|重庆|南京|杭州|武汉|成都|西安)市.*?(路|街|道|巷).*?(号)'
+ },
+ 'replacements': {
+ 'phone': '***********',
+ 'email': '******@*****',
+ 'id_card': '******************',
+ 'credit_card': '****************',
+ 'address': '[地址已隐藏]'
+ }
+ }
+
+ filter_path = os.path.join(security_dir, 'sensitive_filter.json')
+ with open(filter_path, 'w', encoding='utf-8') as f:
+ json.dump(filter_config, f, ensure_ascii=False, indent=2)
+
+ print(f"✅ 敏感信息过滤器配置已保存到 {filter_path}")
+
+ # 创建IP黑名单文件
+ if self.config['security']['enable_ip_blocking']:
+ blacklist_path = os.path.join(security_dir, 'ip_blacklist.txt')
+ if not os.path.exists(blacklist_path):
+ with open(blacklist_path, 'w') as f:
+ f.write("# 每行一个IP地址\n")
+ print(f"✅ IP黑名单文件已创建: {blacklist_path}")
+
+ def _prompt(self, prompt, default=""):
+ """提示用户输入,如果用户直接按回车则返回默认值"""
+ if default:
+ user_input = input(f"{prompt} [{default}]: ").strip()
+ else:
+ user_input = input(f"{prompt}: ").strip()
+
+ return user_input if user_input else default
+
+
+if __name__ == "__main__":
+ wizard = InitWizard()
+ wizard.start()
\ No newline at end of file
diff --git a/utils/model_loader.py b/utils/model_loader.py
new file mode 100644
index 0000000..77bc111
--- /dev/null
+++ b/utils/model_loader.py
@@ -0,0 +1,285 @@
+import os
+import sys
+import pickle
+import marshal
+import types
+import logging
+import torch
+import numpy as np
+import json
+from pathlib import Path
+
+logger = logging.getLogger('model_loader')
+logger.setLevel(logging.INFO)
+
+def load_sentiment_model(model_path, device=None):
+ """
+ 加载情感分析模型
+
+ 参数:
+ model_path: 模型文件路径
+ device: 设备(可忽略,marshal模型不依赖设备)
+
+ 返回:
+ 加载好的模型对象
+ """
+ try:
+ logger.info(f"加载情感分析模型: {model_path}")
+
+ if model_path.endswith('.marshal') or model_path.endswith('.marshal.3'):
+ with open(model_path, 'rb') as f:
+ model_data = marshal.load(f)
+
+ # 将marshal数据转换为可调用的函数对象
+ sentiment_func = types.FunctionType(model_data, globals(), "sentiment_func")
+ logger.info("情感分析模型加载成功")
+ return sentiment_func
+ else:
+ raise ValueError(f"不支持的情感模型格式: {model_path}")
+ except Exception as e:
+ logger.error(f"加载情感分析模型失败: {e}")
+ raise
+
+def load_bert_ctm_model(model_dir, device='cuda' if torch.cuda.is_available() else 'cpu'):
+ """
+ 加载BERT-CTM模型
+
+ 参数:
+ model_dir: 模型目录
+ device: 计算设备
+
+ 返回:
+ 包含模型和分词器的字典
+ """
+ try:
+ logger.info(f"加载BERT-CTM模型: {model_dir}")
+
+ sys.path.append('model_pro')
+ from BERT_CTM import BERT_CTM
+ from transformers import BertTokenizer
+
+ # 加载模型
+ model_path = os.path.join(model_dir, 'final_model.pt') if not model_dir.endswith('.pt') else model_dir
+ model = BERT_CTM()
+ model.load_state_dict(torch.load(model_path, map_location=device))
+ model.to(device)
+ model.eval()
+
+ # 加载分词器
+ tokenizer_path = os.path.join(os.path.dirname(model_dir), 'bert_model')
+ tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
+
+ logger.info("BERT-CTM模型加载成功")
+ return {
+ 'model': model,
+ 'tokenizer': tokenizer,
+ 'device': device
+ }
+ except Exception as e:
+ logger.error(f"加载BERT-CTM模型失败: {e}")
+ raise
+
+def load_bcat_model(model_dir, device='cuda' if torch.cuda.is_available() else 'cpu'):
+ """
+ 加载BCAT模型
+
+ 参数:
+ model_dir: 模型目录
+ device: 计算设备
+
+ 返回:
+ 包含模型和分词器的字典
+ """
+ try:
+ logger.info(f"加载BCAT模型: {model_dir}")
+
+ sys.path.append('model_pro')
+ from BCAT import BCAT
+ from transformers import BertTokenizer
+
+ # 加载模型配置
+ config_path = os.path.join(model_dir, 'config.json')
+ with open(config_path, 'r', encoding='utf-8') as f:
+ config = json.load(f)
+
+ # 初始化模型
+ model = BCAT(**config)
+
+ # 加载模型权重
+ model_path = os.path.join(model_dir, 'model.pt')
+ model.load_state_dict(torch.load(model_path, map_location=device))
+ model.to(device)
+ model.eval()
+
+ # 加载分词器
+ tokenizer_path = os.path.join(model_dir, 'tokenizer')
+ tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
+
+ logger.info("BCAT模型加载成功")
+ return {
+ 'model': model,
+ 'tokenizer': tokenizer,
+ 'device': device,
+ 'config': config
+ }
+ except Exception as e:
+ logger.error(f"加载BCAT模型失败: {e}")
+ raise
+
+def load_topic_classifier(model_dir, device='cuda' if torch.cuda.is_available() else 'cpu'):
+ """
+ 加载话题分类模型
+
+ 参数:
+ model_dir: 模型目录
+ device: 计算设备
+
+ 返回:
+ 包含模型、分词器和标签映射的字典
+ """
+ try:
+ logger.info(f"加载话题分类模型: {model_dir}")
+
+ # 尝试加载transformers模型
+ try:
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+ # 加载模型
+ model = AutoModelForSequenceClassification.from_pretrained(model_dir)
+ model.to(device)
+ model.eval()
+
+ # 加载分词器
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
+
+ # 加载标签映射
+ labels_path = os.path.join(model_dir, 'labels.json')
+ if os.path.exists(labels_path):
+ with open(labels_path, 'r', encoding='utf-8') as f:
+ labels_map = json.load(f)
+ else:
+ # 尝试从config中读取标签
+ if hasattr(model.config, 'id2label'):
+ labels_map = model.config.id2label
+ else:
+ labels_map = {}
+
+ logger.info("话题分类模型加载成功 (transformers)")
+ return {
+ 'model': model,
+ 'tokenizer': tokenizer,
+ 'labels_map': labels_map,
+ 'device': device
+ }
+ except Exception as e:
+ logger.warning(f"使用transformers加载失败,尝试其他方法: {e}")
+
+ # 尝试加载PyTorch模型
+ model_path = os.path.join(model_dir, 'model.pt')
+ if os.path.exists(model_path):
+ model = torch.load(model_path, map_location=device)
+
+ # 加载分词器
+ tokenizer_path = os.path.join(model_dir, 'tokenizer.pkl')
+ if os.path.exists(tokenizer_path):
+ with open(tokenizer_path, 'rb') as f:
+ tokenizer = pickle.load(f)
+ else:
+ tokenizer = None
+
+ # 加载标签映射
+ labels_path = os.path.join(model_dir, 'labels.json')
+ if os.path.exists(labels_path):
+ with open(labels_path, 'r', encoding='utf-8') as f:
+ labels_map = json.load(f)
+ else:
+ labels_map = {}
+
+ logger.info("话题分类模型加载成功 (PyTorch)")
+ return {
+ 'model': model,
+ 'tokenizer': tokenizer,
+ 'labels_map': labels_map,
+ 'device': device
+ }
+
+ raise ValueError(f"无法加载模型: {model_dir}")
+ except Exception as e:
+ logger.error(f"加载话题分类模型失败: {e}")
+ raise
+
+def load_echarts_optimizer():
+ """
+ 加载ECharts优化器,用于提升大数据渲染性能
+
+ 返回:
+ ECharts优化器对象
+ """
+ try:
+ class EChartsOptimizer:
+ def __init__(self):
+ self.chunk_size = 1000 # 分块大小
+ logger.info("ECharts优化器初始化成功")
+
+ def optimize_option(self, option):
+ """优化ECharts配置,提升大数据渲染性能"""
+ if not option:
+ return option
+
+ # 深拷贝以避免修改原始对象
+ import copy
+ option = copy.deepcopy(option)
+
+ # 添加渐进式渲染
+ if 'progressive' not in option:
+ option['progressive'] = 300 # 每帧渲染的数据点数量
+
+ if 'progressiveThreshold' not in option:
+ option['progressiveThreshold'] = 5000 # 启动渐进式渲染的阈值
+
+ if 'series' in option and isinstance(option['series'], list):
+ for series in option['series']:
+ # 对大数据系列应用优化
+ if 'data' in series and isinstance(series['data'], list) and len(series['data']) > 5000:
+ # 大数据采样
+ if series.get('type') in ['scatter', 'line']:
+ self._optimize_large_data_series(series)
+
+ return option
+
+ def _optimize_large_data_series(self, series):
+ """优化大数据系列"""
+ # 添加大数据优化选项
+ series['large'] = True
+ series['largeThreshold'] = 2000
+
+ # 按需设置抽样
+ if len(series['data']) > 50000:
+ # 对非常大的数据集进行抽样
+ step = max(1, len(series['data']) // 50000)
+ series['data'] = series['data'][::step]
+ series['sampling'] = 'average'
+
+ return series
+
+ def chunk_process_data(self, data, process_func):
+ """分块处理大数据"""
+ result = []
+ for i in range(0, len(data), self.chunk_size):
+ chunk = data[i:i + self.chunk_size]
+ result.extend(process_func(chunk))
+ return result
+
+ return EChartsOptimizer()
+ except Exception as e:
+ logger.error(f"加载ECharts优化器失败: {e}")
+ return None
+
+# 导出所有加载函数
+__all__ = [
+ 'load_sentiment_model',
+ 'load_bert_ctm_model',
+ 'load_bcat_model',
+ 'load_topic_classifier',
+ 'load_echarts_optimizer'
+]
\ No newline at end of file
diff --git a/utils/model_manager.py b/utils/model_manager.py
new file mode 100644
index 0000000..398acb4
--- /dev/null
+++ b/utils/model_manager.py
@@ -0,0 +1,364 @@
+import os
+import time
+import threading
+import logging
+import gc
+import torch
+import numpy as np
+from collections import OrderedDict
+from datetime import datetime, timedelta
+
+logger = logging.getLogger('model_manager')
+logger.setLevel(logging.INFO)
+
+class ModelManager:
+ """
+ 模型管理器 - 实现模型预加载和按需卸载技术
+
+ 功能:
+ 1. 预加载经常使用的模型,减少加载等待时间
+ 2. 使用LRU (Least Recently Used) 策略管理内存中加载的模型
+ 3. 支持模型的异步加载和监控
+ 4. 自动检测并释放长时间未使用的模型内存
+ 5. 提供模型使用统计
+ """
+
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls):
+ with cls._lock:
+ if cls._instance is None:
+ cls._instance = super(ModelManager, cls).__new__(cls)
+ return cls._instance
+
+ def __init__(self):
+ if hasattr(self, 'initialized'):
+ return
+
+ # 已加载模型的缓存,使用OrderedDict实现LRU
+ self.loaded_models = OrderedDict()
+ # 模型使用统计
+ self.model_stats = {}
+ # 模型预热配置
+ self.preload_config = {}
+ # 最大内存占用(GB)
+ self.max_memory_usage = float(os.getenv('MAX_MODEL_MEMORY_USAGE', '4.0'))
+ # 模型加载中的锁
+ self.loading_locks = {}
+ # 模型卸载超时(分钟)
+ self.unload_timeout = int(os.getenv('MODEL_UNLOAD_TIMEOUT', '30'))
+
+ # 启动模型监控线程
+ self.monitor_thread = threading.Thread(target=self._monitor_models, daemon=True)
+ self.monitor_thread.start()
+
+ self.initialized = True
+ logger.info(f"模型管理器初始化完成,最大内存占用: {self.max_memory_usage}GB")
+
+ def register_model(self, model_id, model_path, preload=False, model_size_gb=0.5,
+ load_function=None, device='cuda' if torch.cuda.is_available() else 'cpu'):
+ """
+ 注册模型,可选设置为预加载
+
+ 参数:
+ model_id: 模型唯一标识符
+ model_path: 模型路径
+ preload: 是否预加载
+ model_size_gb: 模型估计大小(GB)
+ load_function: 自定义加载函数,签名为 load_function(model_path, device) -> model
+ device: 加载模型的设备
+ """
+ self.preload_config[model_id] = {
+ 'model_path': model_path,
+ 'preload': preload,
+ 'model_size_gb': model_size_gb,
+ 'load_function': load_function,
+ 'device': device
+ }
+
+ self.model_stats[model_id] = {
+ 'load_count': 0,
+ 'use_count': 0,
+ 'total_load_time': 0,
+ 'last_used': None,
+ 'avg_load_time': 0
+ }
+
+ if preload:
+ logger.info(f"模型 {model_id} 已注册并标记为预加载")
+ # 启动预加载线程
+ threading.Thread(target=self._preload_model, args=(model_id,), daemon=True).start()
+ else:
+ logger.info(f"模型 {model_id} 已注册")
+
+ return True
+
+ def get_model(self, model_id):
+ """
+ 获取模型,如果未加载则加载
+
+ 参数:
+ model_id: 模型唯一标识符
+
+ 返回:
+ 加载好的模型对象
+ """
+ if model_id not in self.preload_config:
+ raise ValueError(f"模型 {model_id} 未注册")
+
+ # 更新最后使用时间
+ self.model_stats[model_id]['last_used'] = datetime.now()
+ self.model_stats[model_id]['use_count'] += 1
+
+ # 检查模型是否已加载
+ if model_id in self.loaded_models:
+ # 将模型移至OrderedDict末尾,表示最近使用
+ model = self.loaded_models.pop(model_id)
+ self.loaded_models[model_id] = model
+ logger.debug(f"使用已加载的模型: {model_id}")
+ return model
+
+ # 获取模型加载锁,防止并发加载同一模型
+ if model_id not in self.loading_locks:
+ self.loading_locks[model_id] = threading.Lock()
+
+ # 加锁加载模型
+ with self.loading_locks[model_id]:
+ # 再次检查模型是否已被其他线程加载
+ if model_id in self.loaded_models:
+ return self.loaded_models[model_id]
+
+ # 检查是否有足够内存
+ self._ensure_memory_available(self.preload_config[model_id]['model_size_gb'])
+
+ # 加载模型
+ start_time = time.time()
+ model = self._load_model(model_id)
+ load_time = time.time() - start_time
+
+ # 更新统计
+ self.model_stats[model_id]['load_count'] += 1
+ self.model_stats[model_id]['total_load_time'] += load_time
+ self.model_stats[model_id]['avg_load_time'] = (
+ self.model_stats[model_id]['total_load_time'] /
+ self.model_stats[model_id]['load_count']
+ )
+
+ logger.info(f"模型 {model_id} 加载完成,耗时: {load_time:.2f}秒")
+
+ # 存储模型
+ self.loaded_models[model_id] = model
+ return model
+
+ def unload_model(self, model_id):
+ """
+ 手动卸载模型
+
+ 参数:
+ model_id: 模型唯一标识符
+ """
+ if model_id in self.loaded_models:
+ logger.info(f"手动卸载模型: {model_id}")
+ del self.loaded_models[model_id]
+ # 强制垃圾回收
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ return True
+ return False
+
+ def get_model_stats(self):
+ """获取所有模型的使用统计"""
+ result = {}
+ for model_id, stats in self.model_stats.items():
+ is_loaded = model_id in self.loaded_models
+ result[model_id] = {
+ **stats,
+ 'is_loaded': is_loaded,
+ 'preload': self.preload_config[model_id]['preload'],
+ 'model_size_gb': self.preload_config[model_id]['model_size_gb'],
+ 'device': self.preload_config[model_id]['device'],
+ }
+ return result
+
+ def preload_all(self):
+ """预加载所有标记为预加载的模型"""
+ for model_id, config in self.preload_config.items():
+ if config['preload'] and model_id not in self.loaded_models:
+ threading.Thread(target=self._preload_model, args=(model_id,), daemon=True).start()
+
+ def _preload_model(self, model_id):
+ """预加载单个模型的内部方法"""
+ try:
+ logger.info(f"开始预加载模型: {model_id}")
+ # 确保有足够内存
+ self._ensure_memory_available(self.preload_config[model_id]['model_size_gb'])
+
+ # 加载模型
+ start_time = time.time()
+ model = self._load_model(model_id)
+ load_time = time.time() - start_time
+
+ # 更新统计
+ self.model_stats[model_id]['load_count'] += 1
+ self.model_stats[model_id]['total_load_time'] += load_time
+ self.model_stats[model_id]['avg_load_time'] = (
+ self.model_stats[model_id]['total_load_time'] /
+ self.model_stats[model_id]['load_count']
+ )
+
+ # 存储模型
+ self.loaded_models[model_id] = model
+ logger.info(f"模型 {model_id} 预加载完成,耗时: {load_time:.2f}秒")
+
+ except Exception as e:
+ logger.error(f"预加载模型 {model_id} 失败: {e}")
+
+ def _load_model(self, model_id):
+ """加载模型的内部方法"""
+ config = self.preload_config[model_id]
+
+ if config['load_function'] is not None:
+ # 使用自定义加载函数
+ return config['load_function'](config['model_path'], config['device'])
+
+ # 默认加载逻辑 - 根据文件扩展名确定加载方式
+ model_path = config['model_path']
+ device = config['device']
+
+ if model_path.endswith('.pt') or model_path.endswith('.pth'):
+ # PyTorch模型
+ return torch.load(model_path, map_location=device)
+ elif model_path.endswith('.pkl'):
+ # Pickle模型
+ import pickle
+ with open(model_path, 'rb') as f:
+ return pickle.load(f)
+ else:
+ # 尝试作为目录加载
+ if os.path.isdir(model_path):
+ # 如果是目录,尝试加载预训练模型
+ try:
+ from transformers import AutoModel, AutoTokenizer
+ model = AutoModel.from_pretrained(model_path)
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ return {'model': model.to(device), 'tokenizer': tokenizer}
+ except ImportError:
+ logger.error("transformers库未安装,无法加载预训练模型")
+ raise
+ except Exception as e:
+ logger.error(f"加载预训练模型失败: {e}")
+ raise
+
+ raise ValueError(f"无法确定如何加载模型: {model_path}")
+
+ def _ensure_memory_available(self, required_gb):
+ """确保有足够的内存来加载新模型"""
+ # 如果当前没有加载的模型,直接返回
+ if not self.loaded_models:
+ return
+
+ # 计算当前已加载模型的总内存
+ current_usage = sum(
+ self.preload_config[model_id]['model_size_gb']
+ for model_id in self.loaded_models
+ )
+
+ # 如果添加新模型后超过限制,需要卸载一些模型
+ while current_usage + required_gb > self.max_memory_usage and self.loaded_models:
+ # 卸载最久未使用的模型(OrderedDict的首项)
+ oldest_model_id, _ = next(iter(self.loaded_models.items()))
+ # 检查是否是预加载且最近使用过的模型
+ if (self.preload_config[oldest_model_id]['preload'] and
+ self.model_stats[oldest_model_id]['last_used'] and
+ (datetime.now() - self.model_stats[oldest_model_id]['last_used']) <
+ timedelta(minutes=self.unload_timeout)):
+ # 跳过预加载且最近使用过的模型
+ # 将该模型移至OrderedDict末尾
+ model = self.loaded_models.pop(oldest_model_id)
+ self.loaded_models[oldest_model_id] = model
+ # 如果所有模型都是预加载的且最近使用过,允许超过限制
+ if len(self.loaded_models) <= 1:
+ break
+ continue
+
+ # 卸载模型并更新内存使用
+ model_size = self.preload_config[oldest_model_id]['model_size_gb']
+ del self.loaded_models[oldest_model_id]
+ current_usage -= model_size
+ logger.info(f"自动卸载模型以释放内存: {oldest_model_id} ({model_size}GB)")
+
+ # 强制垃圾回收
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ def _monitor_models(self):
+ """监控并管理模型的内部线程方法"""
+ while True:
+ try:
+ # 检查长时间未使用的非预加载模型
+ current_time = datetime.now()
+ for model_id in list(self.loaded_models.keys()):
+ if (not self.preload_config[model_id]['preload'] and
+ self.model_stats[model_id]['last_used'] and
+ (current_time - self.model_stats[model_id]['last_used']) >
+ timedelta(minutes=self.unload_timeout)):
+ # 卸载长时间未使用的非预加载模型
+ logger.info(f"卸载长时间未使用的模型: {model_id}")
+ del self.loaded_models[model_id]
+ # 强制垃圾回收
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # 每5分钟检查一次
+ time.sleep(300)
+ except Exception as e:
+ logger.error(f"模型监控线程出错: {e}")
+ time.sleep(300)
+
+# 创建全局模型管理器实例
+model_manager = ModelManager()
+
+# 注册示例函数
+def register_sentiment_model():
+ """注册情感分析模型示例"""
+ from utils.model_loader import load_sentiment_model # 假设您有一个加载情感模型的函数
+
+ try:
+ model_path = os.path.join('model', 'sentiment.marshal.3')
+ model_manager.register_model(
+ model_id='sentiment_basic',
+ model_path=model_path,
+ preload=True,
+ model_size_gb=0.2,
+ load_function=load_sentiment_model
+ )
+ return True
+ except Exception as e:
+ logger.error(f"注册情感分析模型失败: {e}")
+ return False
+
+def register_bert_model():
+ """注册BERT模型示例"""
+ try:
+ model_path = os.path.join('model_pro', 'bert_model')
+ model_manager.register_model(
+ model_id='bert_classifier',
+ model_path=model_path,
+ preload=True,
+ model_size_gb=0.8
+ )
+ return True
+ except Exception as e:
+ logger.error(f"注册BERT模型失败: {e}")
+ return False
+
+# 自动注册常用模型(在导入时执行)
+try:
+ register_sentiment_model()
+ register_bert_model()
+except Exception as e:
+ logger.error(f"自动注册模型失败: {e}")
\ No newline at end of file
diff --git a/utils/model_router.py b/utils/model_router.py
new file mode 100644
index 0000000..e05f1c1
--- /dev/null
+++ b/utils/model_router.py
@@ -0,0 +1,494 @@
+import os
+import json
+import logging
+import re
+from collections import defaultdict
+import random
+import torch
+import numpy as np
+from datetime import datetime
+from typing import Dict, List, Any, Tuple, Optional, Union, Callable
+
+logger = logging.getLogger('model_router')
+logger.setLevel(logging.INFO)
+
+class ModelRouter:
+ """
+ 模型路由器 - 自动根据内容类型选择最优的AI模型
+
+ 功能:
+ 1. 根据内容类型和任务需求,自动选择最合适的AI模型
+ 2. 支持多种模型供应商和模型类型
+ 3. 考虑性能、成本和准确度等因素进行智能路由
+ 4. 学习和适应用户偏好和使用模式
+ 5. 提供标准化的API接口,支持私有模型集成
+ """
+
+ _instance = None
+
+ def __new__(cls):
+ if cls._instance is None:
+ cls._instance = super(ModelRouter, cls).__new__(cls)
+ cls._instance._initialized = False
+ return cls._instance
+
+ def __init__(self):
+ if self._initialized:
+ return
+
+ # 支持的模型定义
+ self.models = {
+ # OpenAI 模型
+ 'gpt-4o-latest': {
+ 'provider': 'openai',
+ 'capabilities': {
+ 'text_analysis': 0.95,
+ 'sentiment_analysis': 0.92,
+ 'keyword_extraction': 0.90,
+ 'summarization': 0.93,
+ 'classification': 0.89,
+ 'chinese_text': 0.88
+ },
+ 'cost_per_1k': 0.01,
+ 'max_tokens': 128000,
+ 'avg_latency': 2.5, # 秒
+ 'requires_api_key': 'OPENAI_API_KEY'
+ },
+ 'gpt-4o-mini': {
+ 'provider': 'openai',
+ 'capabilities': {
+ 'text_analysis': 0.85,
+ 'sentiment_analysis': 0.82,
+ 'keyword_extraction': 0.80,
+ 'summarization': 0.84,
+ 'classification': 0.81,
+ 'chinese_text': 0.79
+ },
+ 'cost_per_1k': 0.00015,
+ 'max_tokens': 4000,
+ 'avg_latency': 1.2,
+ 'requires_api_key': 'OPENAI_API_KEY'
+ },
+ 'gpt-3.5-turbo': {
+ 'provider': 'openai',
+ 'capabilities': {
+ 'text_analysis': 0.75,
+ 'sentiment_analysis': 0.72,
+ 'keyword_extraction': 0.70,
+ 'summarization': 0.77,
+ 'classification': 0.73,
+ 'chinese_text': 0.65
+ },
+ 'cost_per_1k': 0.0015,
+ 'max_tokens': 16000,
+ 'avg_latency': 0.8,
+ 'requires_api_key': 'OPENAI_API_KEY'
+ },
+
+ # Claude 模型
+ 'claude-3.5-sonnet': {
+ 'provider': 'anthropic',
+ 'capabilities': {
+ 'text_analysis': 0.90,
+ 'sentiment_analysis': 0.91,
+ 'keyword_extraction': 0.85,
+ 'summarization': 0.92,
+ 'classification': 0.89,
+ 'chinese_text': 0.80
+ },
+ 'cost_per_1k': 0.015,
+ 'max_tokens': 200000,
+ 'avg_latency': 2.8,
+ 'requires_api_key': 'ANTHROPIC_API_KEY'
+ },
+ 'claude-3.5-haiku': {
+ 'provider': 'anthropic',
+ 'capabilities': {
+ 'text_analysis': 0.84,
+ 'sentiment_analysis': 0.83,
+ 'keyword_extraction': 0.79,
+ 'summarization': 0.85,
+ 'classification': 0.80,
+ 'chinese_text': 0.72
+ },
+ 'cost_per_1k': 0.0025,
+ 'max_tokens': 200000,
+ 'avg_latency': 1.5,
+ 'requires_api_key': 'ANTHROPIC_API_KEY'
+ },
+
+ # DeepSeek 模型
+ 'deepseek-chat': {
+ 'provider': 'deepseek',
+ 'capabilities': {
+ 'text_analysis': 0.82,
+ 'sentiment_analysis': 0.79,
+ 'keyword_extraction': 0.77,
+ 'summarization': 0.80,
+ 'classification': 0.77,
+ 'chinese_text': 0.90 # 特别好中文
+ },
+ 'cost_per_1k': 0.002,
+ 'max_tokens': 4000,
+ 'avg_latency': 1.0,
+ 'requires_api_key': 'DEEPSEEK_API_KEY'
+ },
+ 'deepseek-reasoner': {
+ 'provider': 'deepseek',
+ 'capabilities': {
+ 'text_analysis': 0.87,
+ 'sentiment_analysis': 0.75,
+ 'keyword_extraction': 0.76,
+ 'summarization': 0.78,
+ 'classification': 0.85,
+ 'chinese_text': 0.88
+ },
+ 'cost_per_1k': 0.003,
+ 'max_tokens': 4000,
+ 'avg_latency': 1.8,
+ 'requires_api_key': 'DEEPSEEK_API_KEY'
+ }
+ }
+
+ # 任务类型定义
+ self.task_types = {
+ 'sentiment_analysis': {
+ 'description': '情感分析',
+ 'key_capabilities': ['sentiment_analysis', 'text_analysis'],
+ 'example_prompt': '分析以下文本的情感倾向(积极、消极或中性)'
+ },
+ 'topic_classification': {
+ 'description': '话题分类',
+ 'key_capabilities': ['classification', 'text_analysis'],
+ 'example_prompt': '将以下文本分类到最合适的话题类别'
+ },
+ 'keyword_extraction': {
+ 'description': '关键词提取',
+ 'key_capabilities': ['keyword_extraction', 'text_analysis'],
+ 'example_prompt': '从以下文本中提取5个最重要的关键词'
+ },
+ 'text_summarization': {
+ 'description': '文本摘要',
+ 'key_capabilities': ['summarization', 'text_analysis'],
+ 'example_prompt': '为以下文本生成一个简短的摘要'
+ },
+ 'comprehensive_analysis': {
+ 'description': '综合分析',
+ 'key_capabilities': ['text_analysis', 'sentiment_analysis', 'keyword_extraction', 'summarization'],
+ 'example_prompt': '对以下文本进行全面分析,包括情感、关键词和主要观点'
+ }
+ }
+
+ # 用户偏好和使用历史
+ self.usage_history = defaultdict(list)
+
+ # 模型可用性缓存
+ self.available_models = {}
+
+ # 更新模型可用性
+ self._update_available_models()
+
+ self._initialized = True
+ logger.info("模型路由器初始化完成")
+
+ def _update_available_models(self):
+ """更新模型可用性"""
+ self.available_models = {}
+
+ for model_id, model_info in self.models.items():
+ # 检查API密钥是否可用
+ api_key_env = model_info.get('requires_api_key')
+ if api_key_env and os.getenv(api_key_env):
+ self.available_models[model_id] = model_info
+
+ if not self.available_models:
+ logger.warning("未找到可用的模型,请检查API密钥配置")
+ else:
+ logger.info(f"找到 {len(self.available_models)} 个可用模型")
+
+ def detect_content_type(self, text: str) -> Dict[str, float]:
+ """
+ 检测内容类型和特征
+
+ 参数:
+ text: 要分析的文本
+
+ 返回:
+ 内容类型特征字典,键为特征名称,值为权重
+ """
+ features = {
+ 'chinese_text': 0.0,
+ 'length': 0.0,
+ 'complexity': 0.0
+ }
+
+ if not text:
+ return features
+
+ # 检测中文比例
+ chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
+ total_chars = len(text)
+ chinese_ratio = chinese_chars / total_chars if total_chars > 0 else 0
+
+ # 文本长度评分 (归一化至0-1)
+ length_score = min(1.0, len(text) / 10000)
+
+ # 文本复杂度简单估计
+ # 基于句子长度、词汇多样性等
+ sentences = re.split(r'[.!?。!?]', text)
+ avg_sentence_len = sum(len(s) for s in sentences) / len(sentences) if sentences else 0
+ unique_words = len(set(re.findall(r'\w+', text.lower())))
+ total_words = len(re.findall(r'\w+', text.lower()))
+
+ lexical_diversity = unique_words / total_words if total_words > 0 else 0
+ complexity_score = (avg_sentence_len / 50 + lexical_diversity) / 2
+ complexity_score = min(1.0, complexity_score)
+
+ features['chinese_text'] = chinese_ratio
+ features['length'] = length_score
+ features['complexity'] = complexity_score
+
+ return features
+
+ def select_model(self, text: str, task_type: str,
+ optimize_for: str = 'balanced',
+ exclude_models: List[str] = None) -> str:
+ """
+ 为给定文本和任务选择最合适的模型
+
+ 参数:
+ text: 要处理的文本
+ task_type: 任务类型,如 'sentiment_analysis'
+ optimize_for: 优化目标,可选值:'cost'(成本), 'performance'(性能), 'balanced'(平衡)
+ exclude_models: 要排除的模型列表
+
+ 返回:
+ 选择的模型ID
+ """
+ if not self.available_models:
+ logger.error("没有可用的模型,请检查API密钥配置")
+ return None
+
+ if task_type not in self.task_types:
+ logger.warning(f"未知的任务类型: {task_type},使用默认任务类型: 'comprehensive_analysis'")
+ task_type = 'comprehensive_analysis'
+
+ # 获取内容特征
+ content_features = self.detect_content_type(text)
+
+ # 获取任务关键能力
+ task_capabilities = self.task_types[task_type]['key_capabilities']
+
+ # 计算每个模型的得分
+ model_scores = {}
+ exclude_models = exclude_models or []
+
+ for model_id, model_info in self.available_models.items():
+ if model_id in exclude_models:
+ continue
+
+ # 基于任务能力的得分
+ capability_score = 0
+ for capability in task_capabilities:
+ capability_score += model_info['capabilities'].get(capability, 0)
+
+ capability_score /= len(task_capabilities)
+
+ # 基于内容特征的得分调整
+ content_score = 1.0
+
+ # 如果有大量中文,增加中文能力的权重
+ if content_features['chinese_text'] > 0.5:
+ chinese_capability = model_info['capabilities'].get('chinese_text', 0)
+ content_score *= (1.0 + chinese_capability) / 2
+
+ # 如果文本很长,检查模型的最大token限制
+ if content_features['length'] > 0.7:
+ max_tokens = model_info.get('max_tokens', 4000)
+ if max_tokens < 10000:
+ content_score *= 0.7 # 长文本降低短上下文模型的分数
+
+ # 如果文本很复杂,可能需要更强大的模型
+ if content_features['complexity'] > 0.7:
+ # 假设能力得分更高的模型更能处理复杂文本
+ content_score *= (1.0 + capability_score) / 2
+
+ # 根据优化目标调整最终得分
+ final_score = capability_score * content_score
+
+ if optimize_for == 'cost':
+ # 成本越低,分数越高
+ cost_factor = 1 - min(1.0, model_info.get('cost_per_1k', 0) / 0.03)
+ final_score = final_score * 0.3 + cost_factor * 0.7
+ elif optimize_for == 'performance':
+ # 能力得分权重更高
+ final_score = capability_score * 0.8 + content_score * 0.2
+ # balanced 是默认值,不需要额外调整
+
+ model_scores[model_id] = final_score
+
+ if not model_scores:
+ logger.warning("没有符合条件的可用模型")
+ return list(self.available_models.keys())[0]
+
+ # 选择得分最高的模型
+ selected_model = max(model_scores, key=model_scores.get)
+
+ # 记录使用历史
+ self.usage_history[task_type].append({
+ 'model': selected_model,
+ 'timestamp': datetime.now().timestamp(),
+ 'score': model_scores[selected_model],
+ 'optimize_for': optimize_for
+ })
+
+ logger.info(f"为任务 '{task_type}' 选择了模型: {selected_model} (得分: {model_scores[selected_model]:.4f})")
+ return selected_model
+
+ def get_model_info(self, model_id: str) -> Dict:
+ """获取模型信息"""
+ if model_id in self.models:
+ return self.models[model_id]
+ return None
+
+ def get_available_models(self, refresh: bool = False) -> Dict[str, Dict]:
+ """获取所有可用的模型"""
+ if refresh:
+ self._update_available_models()
+ return self.available_models
+
+ def get_model_by_provider(self, provider: str, optimize_for: str = 'balanced') -> str:
+ """根据提供商获取推荐模型"""
+ provider_models = {
+ model_id: info for model_id, info in self.available_models.items()
+ if info['provider'] == provider
+ }
+
+ if not provider_models:
+ logger.warning(f"未找到提供商 '{provider}' 的可用模型")
+ return None
+
+ if optimize_for == 'cost':
+ # 选择成本最低的模型
+ return min(provider_models.items(), key=lambda x: x[1].get('cost_per_1k', float('inf')))[0]
+ elif optimize_for == 'performance':
+ # 选择性能最好的模型,简单取所有能力的平均值
+ return max(provider_models.items(),
+ key=lambda x: sum(x[1]['capabilities'].values()) / len(x[1]['capabilities']))[0]
+ else:
+ # 平衡模式,综合考虑成本和性能
+ scores = {}
+ for model_id, info in provider_models.items():
+ perf_score = sum(info['capabilities'].values()) / len(info['capabilities'])
+ cost_score = 1 - min(1.0, info.get('cost_per_1k', 0) / 0.03)
+ scores[model_id] = perf_score * 0.5 + cost_score * 0.5
+
+ return max(scores, key=scores.get)
+
+ def get_task_types(self) -> Dict[str, Dict]:
+ """获取支持的任务类型"""
+ return self.task_types
+
+ def register_custom_model(self, model_id: str, model_info: Dict[str, Any]) -> bool:
+ """
+ 注册自定义模型
+
+ 参数:
+ model_id: 模型唯一标识符
+ model_info: 模型信息字典,包含以下字段:
+ - provider: 提供商名称
+ - capabilities: 能力评分字典
+ - cost_per_1k: 每千token的成本
+ - max_tokens: 最大token限制
+ - avg_latency: 平均延迟(秒)
+ - requires_api_key: API密钥环境变量名
+
+ 返回:
+ 是否注册成功
+ """
+ # 验证必要字段
+ required_fields = ['provider', 'capabilities', 'cost_per_1k', 'max_tokens']
+ for field in required_fields:
+ if field not in model_info:
+ logger.error(f"注册模型失败: 缺少必要字段 '{field}'")
+ return False
+
+ # 验证能力评分
+ if not isinstance(model_info['capabilities'], dict):
+ logger.error("注册模型失败: 'capabilities' 必须是字典")
+ return False
+
+ # 添加模型
+ self.models[model_id] = model_info
+
+ # 更新可用模型列表
+ self._update_available_models()
+
+ logger.info(f"成功注册自定义模型: {model_id}")
+ return True
+
+# 创建全局模型路由器实例
+model_router = ModelRouter()
+
+def select_model(text, task_type, optimize_for='balanced', exclude_models=None):
+ """选择最合适的模型"""
+ return model_router.select_model(text, task_type, optimize_for, exclude_models)
+
+def get_available_models(refresh=False):
+ """获取所有可用的模型"""
+ return model_router.get_available_models(refresh)
+
+def get_model_by_provider(provider, optimize_for='balanced'):
+ """根据提供商获取推荐模型"""
+ return model_router.get_model_by_provider(provider, optimize_for)
+
+def get_task_types():
+ """获取支持的任务类型"""
+ return model_router.get_task_types()
+
+def register_custom_model(model_id, model_info):
+ """注册自定义模型"""
+ return model_router.register_custom_model(model_id, model_info)
+
+# 示例用法
+if __name__ == "__main__":
+ # 示例文本
+ chinese_text = """
+ 近日,人工智能技术的发展引发广泛关注。
+ 专家指出,大型语言模型在自然语言处理领域取得了显著进展,
+ 但同时也带来了诸多伦理和安全问题。对此,业界呼吁加强监管,
+ 确保人工智能的发展能够造福人类社会。
+ """
+
+ english_text = """
+ Recent developments in artificial intelligence technology have drawn widespread attention.
+ Experts point out that large language models have made significant progress in the field of natural language processing,
+ but also bring many ethical and security issues. In response, the industry calls for strengthened regulation
+ to ensure that the development of artificial intelligence can benefit human society.
+ """
+
+ # 测试模型选择
+ print("中文文本任务测试:")
+ model_for_chinese = select_model(chinese_text, 'sentiment_analysis')
+ print(f"选择的模型: {model_for_chinese}")
+
+ print("\n英文文本任务测试:")
+ model_for_english = select_model(english_text, 'sentiment_analysis')
+ print(f"选择的模型: {model_for_english}")
+
+ print("\n成本优化测试:")
+ model_for_cost = select_model(chinese_text, 'sentiment_analysis', optimize_for='cost')
+ print(f"选择的模型: {model_for_cost}")
+
+ print("\n性能优化测试:")
+ model_for_perf = select_model(chinese_text, 'sentiment_analysis', optimize_for='performance')
+ print(f"选择的模型: {model_for_perf}")
+
+ # 测试API提供商
+ print("\n根据提供商获取模型:")
+ for provider in ['openai', 'anthropic', 'deepseek']:
+ model = get_model_by_provider(provider)
+ if model:
+ print(f"{provider}: {model}")
+ else:
+ print(f"{provider}: 无可用模型")
\ No newline at end of file
diff --git a/utils/sensitive_filter.py b/utils/sensitive_filter.py
new file mode 100644
index 0000000..e17f7e3
--- /dev/null
+++ b/utils/sensitive_filter.py
@@ -0,0 +1,357 @@
+import re
+import json
+import os
+import logging
+from pathlib import Path
+
+logger = logging.getLogger('sensitive_filter')
+logger.setLevel(logging.INFO)
+
+class SensitiveDataFilter:
+ """
+ 敏感数据过滤器 - 用于检测和屏蔽输出内容中的敏感信息
+
+ 功能:
+ 1. 自动识别并过滤手机号、邮箱、身份证号、信用卡号等敏感信息
+ 2. 支持自定义敏感信息模式和替换文本
+ 3. 提供批量处理和实时过滤功能
+ """
+
+ _instance = None
+
+ def __new__(cls):
+ if cls._instance is None:
+ cls._instance = super(SensitiveDataFilter, cls).__new__(cls)
+ cls._instance._initialized = False
+ return cls._instance
+
+ def __init__(self):
+ if self._initialized:
+ return
+
+ # 默认配置
+ self.config = {
+ 'enabled': os.getenv('ENABLE_SENSITIVE_DATA_FILTER', 'true').lower() == 'true',
+ 'patterns': {
+ 'phone': r'\b1[3-9]\d{9}\b',
+ 'email': r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
+ 'id_card': r'\b[1-9]\d{5}(19|20)\d{2}(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])\d{3}[\dXx]\b',
+ 'credit_card': r'\b\d{4}[ -]?\d{4}[ -]?\d{4}[ -]?\d{4}\b',
+ 'address': r'(北京|上海|广州|深圳|天津|重庆|南京|杭州|武汉|成都|西安)市.*?(路|街|道|巷).*?(号)'
+ },
+ 'replacements': {
+ 'phone': '***********',
+ 'email': '******@*****',
+ 'id_card': '******************',
+ 'credit_card': '****************',
+ 'address': '[地址已隐藏]'
+ }
+ }
+
+ # 加载自定义配置
+ self._load_config()
+
+ # 编译正则表达式
+ self._compile_patterns()
+
+ self._initialized = True
+
+ logger.info("敏感数据过滤器初始化完成")
+ if self.config['enabled']:
+ logger.info(f"已启用以下类型的敏感数据过滤: {', '.join(self.config['patterns'].keys())}")
+ else:
+ logger.info("敏感数据过滤器已禁用")
+
+ def _load_config(self):
+ """加载自定义配置"""
+ # 配置文件路径
+ data_dir = os.getenv('DATA_DIR', 'data')
+ config_path = os.path.join(data_dir, 'security', 'sensitive_filter.json')
+
+ if os.path.exists(config_path):
+ try:
+ with open(config_path, 'r', encoding='utf-8') as f:
+ custom_config = json.load(f)
+
+ # 更新配置
+ if 'enabled' in custom_config:
+ self.config['enabled'] = custom_config['enabled']
+
+ if 'patterns' in custom_config:
+ for key, pattern in custom_config['patterns'].items():
+ self.config['patterns'][key] = pattern
+
+ if 'replacements' in custom_config:
+ for key, replacement in custom_config['replacements'].items():
+ self.config['replacements'][key] = replacement
+
+ logger.info(f"已加载自定义敏感数据过滤配置: {config_path}")
+ except Exception as e:
+ logger.error(f"加载敏感数据过滤配置失败: {e}")
+
+ def _compile_patterns(self):
+ """编译正则表达式"""
+ self.compiled_patterns = {}
+ for key, pattern in self.config['patterns'].items():
+ try:
+ self.compiled_patterns[key] = re.compile(pattern)
+ logger.debug(f"已编译敏感数据模式: {key} - {pattern}")
+ except re.error as e:
+ logger.error(f"编译敏感数据模式失败: {key} - {pattern}: {e}")
+
+ def filter_text(self, text):
+ """
+ 过滤文本中的敏感信息
+
+ 参数:
+ text: 要过滤的文本
+
+ 返回:
+ 过滤后的文本
+ """
+ if not self.config['enabled'] or not text:
+ return text
+
+ filtered_text = text
+ for key, pattern in self.compiled_patterns.items():
+ replacement = self.config['replacements'].get(key, '[FILTERED]')
+ filtered_text = pattern.sub(replacement, filtered_text)
+
+ return filtered_text
+
+ def filter_dict(self, data, *skip_keys):
+ """
+ 过滤字典中的敏感信息
+
+ 参数:
+ data: 要过滤的字典
+ skip_keys: 要跳过的键(不进行过滤)
+
+ 返回:
+ 过滤后的字典
+ """
+ if not self.config['enabled'] or not data:
+ return data
+
+ if not isinstance(data, dict):
+ if isinstance(data, str):
+ return self.filter_text(data)
+ return data
+
+ filtered_data = {}
+ for key, value in data.items():
+ if key in skip_keys:
+ filtered_data[key] = value
+ continue
+
+ if isinstance(value, dict):
+ filtered_data[key] = self.filter_dict(value, *skip_keys)
+ elif isinstance(value, list):
+ filtered_data[key] = [
+ self.filter_dict(item, *skip_keys) if isinstance(item, (dict, list)) else
+ self.filter_text(item) if isinstance(item, str) else item
+ for item in value
+ ]
+ elif isinstance(value, str):
+ filtered_data[key] = self.filter_text(value)
+ else:
+ filtered_data[key] = value
+
+ return filtered_data
+
+ def filter_list(self, data, *skip_keys):
+ """
+ 过滤列表中的敏感信息
+
+ 参数:
+ data: 要过滤的列表
+ skip_keys: 如果列表项是字典,要跳过的键
+
+ 返回:
+ 过滤后的列表
+ """
+ if not self.config['enabled'] or not data:
+ return data
+
+ if not isinstance(data, list):
+ if isinstance(data, dict):
+ return self.filter_dict(data, *skip_keys)
+ if isinstance(data, str):
+ return self.filter_text(data)
+ return data
+
+ return [
+ self.filter_dict(item, *skip_keys) if isinstance(item, dict) else
+ self.filter_list(item, *skip_keys) if isinstance(item, list) else
+ self.filter_text(item) if isinstance(item, str) else item
+ for item in data
+ ]
+
+ def is_sensitive_info(self, text, info_type=None):
+ """
+ 检查文本是否包含敏感信息
+
+ 参数:
+ text: 要检查的文本
+ info_type: 指定要检查的敏感信息类型,如果为None则检查所有类型
+
+ 返回:
+ 包含敏感信息返回True,否则返回False
+ """
+ if not self.config['enabled'] or not text:
+ return False
+
+ if info_type:
+ if info_type not in self.compiled_patterns:
+ logger.warning(f"未知的敏感信息类型: {info_type}")
+ return False
+ return bool(self.compiled_patterns[info_type].search(text))
+
+ for pattern in self.compiled_patterns.values():
+ if pattern.search(text):
+ return True
+
+ return False
+
+ def get_sensitive_info_types(self, text):
+ """
+ 获取文本中包含的敏感信息类型
+
+ 参数:
+ text: 要检查的文本
+
+ 返回:
+ 包含的敏感信息类型列表
+ """
+ if not self.config['enabled'] or not text:
+ return []
+
+ types = []
+ for key, pattern in self.compiled_patterns.items():
+ if pattern.search(text):
+ types.append(key)
+
+ return types
+
+ def enable(self):
+ """启用敏感数据过滤器"""
+ self.config['enabled'] = True
+ logger.info("敏感数据过滤器已启用")
+
+ def disable(self):
+ """禁用敏感数据过滤器"""
+ self.config['enabled'] = False
+ logger.info("敏感数据过滤器已禁用")
+
+ def is_enabled(self):
+ """检查敏感数据过滤器是否启用"""
+ return self.config['enabled']
+
+ def add_pattern(self, key, pattern, replacement='[FILTERED]'):
+ """
+ 添加自定义敏感信息模式
+
+ 参数:
+ key: 敏感信息类型标识
+ pattern: 正则表达式字符串
+ replacement: 替换文本
+ """
+ try:
+ # 测试是否是有效的正则表达式
+ re.compile(pattern)
+
+ # 更新配置
+ self.config['patterns'][key] = pattern
+ self.config['replacements'][key] = replacement
+
+ # 重新编译正则表达式
+ self._compile_patterns()
+
+ logger.info(f"已添加敏感信息模式: {key}")
+ return True
+ except re.error as e:
+ logger.error(f"添加敏感信息模式失败: {key} - {pattern}: {e}")
+ return False
+
+ def remove_pattern(self, key):
+ """
+ 移除敏感信息模式
+
+ 参数:
+ key: 敏感信息类型标识
+ """
+ if key in self.config['patterns']:
+ del self.config['patterns'][key]
+
+ if key in self.config['replacements']:
+ del self.config['replacements'][key]
+
+ if key in self.compiled_patterns:
+ del self.compiled_patterns[key]
+
+ logger.info(f"已移除敏感信息模式: {key}")
+ return True
+
+ logger.warning(f"未找到敏感信息模式: {key}")
+ return False
+
+ def save_config(self):
+ """保存当前配置到文件"""
+ data_dir = os.getenv('DATA_DIR', 'data')
+ security_dir = os.path.join(data_dir, 'security')
+ os.makedirs(security_dir, exist_ok=True)
+
+ config_path = os.path.join(security_dir, 'sensitive_filter.json')
+
+ try:
+ with open(config_path, 'w', encoding='utf-8') as f:
+ json.dump(self.config, f, ensure_ascii=False, indent=2)
+
+ logger.info(f"敏感数据过滤配置已保存到: {config_path}")
+ return True
+ except Exception as e:
+ logger.error(f"保存敏感数据过滤配置失败: {e}")
+ return False
+
+# 创建全局敏感数据过滤器实例
+sensitive_filter = SensitiveDataFilter()
+
+# 提供便捷的过滤函数
+def filter_text(text):
+ """过滤文本中的敏感信息"""
+ return sensitive_filter.filter_text(text)
+
+def filter_dict(data, *skip_keys):
+ """过滤字典中的敏感信息"""
+ return sensitive_filter.filter_dict(data, *skip_keys)
+
+def filter_list(data, *skip_keys):
+ """过滤列表中的敏感信息"""
+ return sensitive_filter.filter_list(data, *skip_keys)
+
+def is_sensitive_info(text, info_type=None):
+ """检查文本是否包含敏感信息"""
+ return sensitive_filter.is_sensitive_info(text, info_type)
+
+# 示例用法
+if __name__ == "__main__":
+ # 测试文本
+ test_text = """
+ 联系人: 张三
+ 电话: 13812345678
+ 邮箱: zhangsan@example.com
+ 身份证: 110101199001011234
+ 地址: 北京市海淀区中关村大街20号
+ 信用卡: 6225 1234 5678 9012
+ """
+
+ # 过滤敏感信息
+ filtered_text = filter_text(test_text)
+ print("原始文本:")
+ print(test_text)
+ print("\n过滤后:")
+ print(filtered_text)
+
+ # 检查敏感信息类型
+ types = sensitive_filter.get_sensitive_info_types(test_text)
+ print(f"\n包含的敏感信息类型: {types}")
\ No newline at end of file
diff --git a/utils/workflow_engine.py b/utils/workflow_engine.py
new file mode 100644
index 0000000..a8ec467
--- /dev/null
+++ b/utils/workflow_engine.py
@@ -0,0 +1,837 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+import json
+import time
+import uuid
+import logging
+import traceback
+from datetime import datetime
+from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor
+
+from utils.db_manager import DatabaseManager
+from utils.cache_manager import CacheManager
+from utils.model_router import ModelRouter
+from utils.sensitive_filter import SensitiveDataFilter
+from spider.weibo_crawler import WeiboCrawler
+from utils.ai_analyzer import AIAnalyzer
+
+# 配置日志
+from utils.logger import setup_logger
+logger = setup_logger('workflow_engine', 'logs/workflow_engine.log')
+
+class WorkflowEngine:
+ """工作流引擎 - 负责执行数据爬取和分析工作流"""
+
+ _instance = None
+ _initialized = False
+
+ def __new__(cls):
+ if cls._instance is None:
+ cls._instance = super(WorkflowEngine, cls).__new__(cls)
+ return cls._instance
+
+ def __init__(self):
+ if self._initialized:
+ return
+
+ self.db = DatabaseManager()
+ self.cache = CacheManager(memory_capacity=50, cache_duration=3600)
+ self.model_router = ModelRouter()
+ self.sensitive_filter = SensitiveDataFilter()
+ self.executor = ThreadPoolExecutor(max_workers=5)
+ self.running_tasks = {}
+
+ # 创建必要的目录
+ self.data_dir = Path('data/workflow')
+ self.data_dir.mkdir(parents=True, exist_ok=True)
+
+ self._initialized = True
+ logger.info("工作流引擎初始化完成")
+
+ def execute_crawler_workflow(self, task_id, config):
+ """
+ 执行爬虫工作流
+
+ Args:
+ task_id: 任务ID
+ config: 爬虫配置
+ """
+ logger.info(f"开始执行爬虫工作流: {task_id}")
+
+ try:
+ # 更新任务状态为运行中
+ self._update_task_status(task_id, 'running', 0)
+
+ # 创建爬虫实例
+ crawler = WeiboCrawler()
+
+ # 设置爬虫参数
+ source = config.get('source', 'hot_topics')
+ depth = config.get('crawl_depth', 1)
+ interval = config.get('interval', 5)
+ filters = config.get('filters', {})
+
+ # 执行爬取
+ result = crawler.crawl(
+ source=source,
+ depth=depth,
+ interval=interval,
+ filters=filters,
+ callback=lambda progress: self._update_task_progress(task_id, progress)
+ )
+
+ # 更新任务状态为已完成
+ self._update_task_status(task_id, 'completed', 100, result=result)
+ logger.info(f"爬虫工作流完成: {task_id}")
+
+ return result
+
+ except Exception as e:
+ logger.error(f"爬虫工作流出错: {str(e)}")
+ logger.error(traceback.format_exc())
+ self._update_task_status(task_id, 'failed', 0, error=str(e))
+ return None
+
+ def execute_analysis_workflow(self, task_id, workflow):
+ """
+ 执行分析工作流
+
+ Args:
+ task_id: 任务ID
+ workflow: 工作流配置
+ """
+ logger.info(f"开始执行分析工作流: {task_id}")
+
+ try:
+ # 更新任务状态为运行中
+ self._update_task_status(task_id, 'running', 0)
+
+ components = workflow.get('components', [])
+ connections = workflow.get('connections', [])
+
+ # 验证工作流
+ if not components or not connections:
+ raise ValueError("工作流配置不完整,缺少组件或连接")
+
+ # 构建组件依赖图
+ component_map, dependency_graph = self._build_dependency_graph(components, connections)
+
+ # 进行拓扑排序
+ execution_order = self._topological_sort(dependency_graph)
+
+ # 执行组件
+ result_map = {}
+ total_components = len(execution_order)
+
+ for idx, component_id in enumerate(execution_order):
+ component = component_map.get(component_id)
+ if not component:
+ continue
+
+ # 计算总体进度
+ progress = int((idx / total_components) * 100)
+ self._update_task_progress(task_id, progress)
+
+ # 获取输入数据
+ input_data = self._get_component_input_data(component_id, connections, result_map)
+
+ # 执行组件
+ result = self._execute_component(component, input_data)
+
+ # 存储结果
+ result_map[component_id] = result
+
+ # 获取最终输出
+ final_outputs = self._get_final_outputs(dependency_graph, result_map)
+
+ # 应用敏感信息过滤
+ if final_outputs and self.sensitive_filter.is_enabled():
+ if isinstance(final_outputs, dict):
+ final_outputs = self.sensitive_filter.filter_dict(final_outputs)
+ elif isinstance(final_outputs, list):
+ final_outputs = self.sensitive_filter.filter_list(final_outputs)
+
+ # 更新任务状态为已完成
+ self._update_task_status(task_id, 'completed', 100, result=final_outputs)
+ logger.info(f"分析工作流完成: {task_id}")
+
+ return final_outputs
+
+ except Exception as e:
+ logger.error(f"分析工作流出错: {str(e)}")
+ logger.error(traceback.format_exc())
+ self._update_task_status(task_id, 'failed', 0, error=str(e))
+ return None
+
+ def start_workflow(self, workflow_type, config, template_id=None):
+ """
+ 异步启动工作流
+
+ Args:
+ workflow_type: 工作流类型 (crawler/analysis)
+ config: 工作流配置
+ template_id: 关联的模板ID
+
+ Returns:
+ task_id: 工作流任务ID
+ """
+ # 生成任务ID
+ task_id = str(uuid.uuid4())
+
+ # 保存任务信息到数据库
+ conn = self.db.get_connection()
+ cursor = conn.cursor()
+
+ try:
+ now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+ cursor.execute(
+ """
+ INSERT INTO workflow_tasks
+ (id, template_id, type, status, progress, config, created_at, updated_at)
+ VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
+ """,
+ (
+ task_id,
+ template_id,
+ workflow_type,
+ 'pending',
+ 0,
+ json.dumps(config, ensure_ascii=False),
+ now,
+ now
+ )
+ )
+ conn.commit()
+
+ # 异步执行工作流
+ if workflow_type == 'crawler':
+ self.running_tasks[task_id] = self.executor.submit(
+ self.execute_crawler_workflow, task_id, config
+ )
+ elif workflow_type == 'analysis':
+ self.running_tasks[task_id] = self.executor.submit(
+ self.execute_analysis_workflow, task_id, config
+ )
+ else:
+ logger.error(f"未知的工作流类型: {workflow_type}")
+ return None
+
+ return task_id
+
+ except Exception as e:
+ logger.error(f"启动工作流失败: {str(e)}")
+ conn.rollback()
+ return None
+ finally:
+ cursor.close()
+
+ def get_task_status(self, task_id):
+ """
+ 获取任务状态
+
+ Args:
+ task_id: 任务ID
+
+ Returns:
+ task: 任务信息
+ """
+ # 先检查缓存
+ cache_key = f"task_status:{task_id}"
+ cached_task = self.cache.get(cache_key)
+ if cached_task:
+ return cached_task
+
+ # 从数据库获取
+ conn = self.db.get_connection()
+ cursor = conn.cursor()
+
+ try:
+ cursor.execute(
+ "SELECT * FROM workflow_tasks WHERE id = %s",
+ (task_id,)
+ )
+ task = cursor.fetchone()
+
+ if task:
+ # 将JSON字符串转为Python对象
+ if task.get('config'):
+ task['config'] = json.loads(task['config'])
+ if task.get('result'):
+ task['result'] = json.loads(task['result'])
+
+ # 缓存结果
+ self.cache.set(cache_key, task)
+
+ return task
+
+ except Exception as e:
+ logger.error(f"获取任务状态失败: {str(e)}")
+ return None
+ finally:
+ cursor.close()
+
+ def cancel_task(self, task_id):
+ """
+ 取消任务
+
+ Args:
+ task_id: 任务ID
+
+ Returns:
+ success: 是否成功
+ """
+ # 检查任务是否存在并正在运行
+ if task_id in self.running_tasks:
+ # 尝试取消任务
+ future = self.running_tasks[task_id]
+ if not future.done():
+ future.cancel()
+
+ # 从运行列表中移除
+ del self.running_tasks[task_id]
+
+ # 更新数据库状态
+ conn = self.db.get_connection()
+ cursor = conn.cursor()
+
+ try:
+ now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+ cursor.execute(
+ """
+ UPDATE workflow_tasks
+ SET status = %s, updated_at = %s
+ WHERE id = %s
+ """,
+ ('cancelled', now, task_id)
+ )
+ conn.commit()
+
+ # 清理缓存
+ cache_key = f"task_status:{task_id}"
+ self.cache.delete(cache_key)
+
+ return True
+
+ except Exception as e:
+ logger.error(f"取消任务失败: {str(e)}")
+ conn.rollback()
+ return False
+ finally:
+ cursor.close()
+
+ def _update_task_status(self, task_id, status, progress, result=None, error=None):
+ """更新任务状态"""
+ conn = self.db.get_connection()
+ cursor = conn.cursor()
+
+ try:
+ now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+ update_fields = ["status = %s", "progress = %s", "updated_at = %s"]
+ params = [status, progress, now]
+
+ # 添加开始时间
+ if status == 'running' and progress == 0:
+ update_fields.append("started_at = %s")
+ params.append(now)
+
+ # 添加完成时间
+ if status in ['completed', 'failed']:
+ update_fields.append("completed_at = %s")
+ params.append(now)
+
+ # 添加结果
+ if result is not None:
+ update_fields.append("result = %s")
+ params.append(json.dumps(result, ensure_ascii=False))
+
+ # 添加错误
+ if error is not None:
+ update_fields.append("error = %s")
+ params.append(error)
+
+ # 构建SQL
+ sql = f"""
+ UPDATE workflow_tasks
+ SET {', '.join(update_fields)}
+ WHERE id = %s
+ """
+ params.append(task_id)
+
+ cursor.execute(sql, tuple(params))
+ conn.commit()
+
+ # 清理缓存
+ cache_key = f"task_status:{task_id}"
+ self.cache.delete(cache_key)
+
+ except Exception as e:
+ logger.error(f"更新任务状态失败: {str(e)}")
+ conn.rollback()
+ finally:
+ cursor.close()
+
+ def _update_task_progress(self, task_id, progress):
+ """更新任务进度"""
+ self._update_task_status(task_id, 'running', progress)
+
+ def _build_dependency_graph(self, components, connections):
+ """构建组件依赖图"""
+ component_map = {comp['id']: comp for comp in components}
+ dependency_graph = {comp['id']: [] for comp in components}
+
+ # 构建依赖关系
+ for conn in connections:
+ source = conn.get('source')
+ target = conn.get('target')
+
+ if source and target and source in component_map and target in component_map:
+ dependency_graph[target].append(source)
+
+ return component_map, dependency_graph
+
+ def _topological_sort(self, graph):
+ """拓扑排序,确定组件执行顺序"""
+ visited = set()
+ temp = set()
+ order = []
+
+ def visit(node):
+ if node in temp:
+ raise ValueError(f"工作流存在循环依赖: {node}")
+ if node in visited:
+ return
+
+ temp.add(node)
+ for neighbor in graph.get(node, []):
+ visit(neighbor)
+
+ temp.remove(node)
+ visited.add(node)
+ order.append(node)
+
+ for node in graph:
+ if node not in visited:
+ visit(node)
+
+ return list(reversed(order))
+
+ def _get_component_input_data(self, component_id, connections, result_map):
+ """获取组件的输入数据"""
+ input_data = {}
+
+ for conn in connections:
+ if conn.get('target') == component_id:
+ source_id = conn.get('source')
+ if source_id in result_map:
+ input_name = conn.get('targetInput', 'default')
+ input_data[input_name] = result_map[source_id]
+
+ return input_data
+
+ def _execute_component(self, component, input_data):
+ """执行单个组件"""
+ component_type = component.get('type')
+ config = component.get('config', {})
+
+ if component_type == 'data_source':
+ return self._execute_data_source(config, input_data)
+ elif component_type == 'preprocessing':
+ return self._execute_preprocessing(config, input_data)
+ elif component_type == 'model':
+ return self._execute_model(config, input_data)
+ elif component_type == 'visualization':
+ return self._execute_visualization(config, input_data)
+ else:
+ logger.warning(f"未知的组件类型: {component_type}")
+ return None
+
+ def _execute_data_source(self, config, input_data):
+ """执行数据源组件"""
+ source_type = config.get('source_type')
+
+ if source_type == 'database':
+ # 从数据库获取数据
+ table = config.get('table')
+ filters = config.get('filters', {})
+ limit = config.get('limit', 1000)
+
+ query_conditions = []
+ query_params = []
+
+ for key, value in filters.items():
+ if value:
+ query_conditions.append(f"{key} = %s")
+ query_params.append(value)
+
+ where_clause = f"WHERE {' AND '.join(query_conditions)}" if query_conditions else ""
+
+ sql = f"SELECT * FROM {table} {where_clause} LIMIT {limit}"
+
+ conn = self.db.get_connection()
+ cursor = conn.cursor()
+
+ try:
+ cursor.execute(sql, tuple(query_params))
+ return cursor.fetchall()
+ except Exception as e:
+ logger.error(f"数据库查询出错: {str(e)}")
+ return []
+ finally:
+ cursor.close()
+
+ elif source_type == 'file':
+ # 从文件加载数据
+ file_path = config.get('file_path')
+ if not file_path or not os.path.exists(file_path):
+ return []
+
+ try:
+ with open(file_path, 'r', encoding='utf-8') as f:
+ if file_path.endswith('.json'):
+ return json.load(f)
+ else:
+ return f.read()
+ except Exception as e:
+ logger.error(f"文件读取出错: {str(e)}")
+ return []
+
+ elif source_type == 'api':
+ # 这里需要实现API调用逻辑
+ # 由于涉及复杂的HTTP请求,暂不实现
+ logger.warning("API数据源暂未实现")
+ return []
+
+ else:
+ logger.warning(f"未知的数据源类型: {source_type}")
+ return []
+
+ def _execute_preprocessing(self, config, input_data):
+ """执行数据预处理组件"""
+ preprocessing_type = config.get('preprocessing_type')
+ data = input_data.get('default', [])
+
+ if not data:
+ return []
+
+ if preprocessing_type == 'filter':
+ # 过滤数据
+ field = config.get('field')
+ value = config.get('value')
+ operator = config.get('operator', 'eq')
+
+ if not field:
+ return data
+
+ result = []
+ for item in data:
+ if operator == 'eq' and item.get(field) == value:
+ result.append(item)
+ elif operator == 'neq' and item.get(field) != value:
+ result.append(item)
+ elif operator == 'contains' and value in str(item.get(field, '')):
+ result.append(item)
+ elif operator == 'not_contains' and value not in str(item.get(field, '')):
+ result.append(item)
+
+ return result
+
+ elif preprocessing_type == 'sort':
+ # 排序数据
+ field = config.get('field')
+ order = config.get('order', 'asc')
+
+ if not field:
+ return data
+
+ return sorted(
+ data,
+ key=lambda x: x.get(field, ''),
+ reverse=(order == 'desc')
+ )
+
+ elif preprocessing_type == 'aggregate':
+ # 聚合数据
+ group_by = config.get('group_by')
+ aggregate_field = config.get('aggregate_field')
+ aggregate_type = config.get('aggregate_type', 'count')
+
+ if not group_by:
+ return data
+
+ result = {}
+ for item in data:
+ key = item.get(group_by)
+ if key not in result:
+ result[key] = {
+ 'count': 0,
+ 'sum': 0,
+ 'values': []
+ }
+
+ result[key]['count'] += 1
+
+ if aggregate_field:
+ value = item.get(aggregate_field, 0)
+ if isinstance(value, (int, float)):
+ result[key]['sum'] += value
+ result[key]['values'].append(value)
+
+ # 计算最终结果
+ final_result = []
+ for key, values in result.items():
+ item = {group_by: key}
+
+ if aggregate_type == 'count':
+ item['value'] = values['count']
+ elif aggregate_type == 'sum':
+ item['value'] = values['sum']
+ elif aggregate_type == 'avg':
+ item['value'] = values['sum'] / values['count'] if values['count'] > 0 else 0
+
+ final_result.append(item)
+
+ return final_result
+
+ else:
+ logger.warning(f"未知的预处理类型: {preprocessing_type}")
+ return data
+
+ def _execute_model(self, config, input_data):
+ """执行模型组件"""
+ model_type = config.get('model_type')
+ data = input_data.get('default', [])
+
+ if not data:
+ return []
+
+ analyzer = AIAnalyzer()
+
+ if model_type == 'sentiment':
+ # 情感分析
+ texts = []
+ if isinstance(data, list):
+ # 如果是列表,从指定字段获取文本
+ field = config.get('text_field', 'content')
+ texts = [item.get(field, '') for item in data if item.get(field)]
+ elif isinstance(data, str):
+ # 如果是字符串,直接使用
+ texts = [data]
+
+ # 获取合适的模型
+ model = self.model_router.select_model_for_text(texts[0] if texts else "", "sentiment")
+
+ # 执行分析
+ results = []
+ for text in texts:
+ result = analyzer.analyze_sentiment(text, model=model)
+ results.append(result)
+
+ # 如果输入是列表,将结果合并回原始数据
+ if isinstance(data, list):
+ field = config.get('text_field', 'content')
+ for i, item in enumerate(data):
+ if i < len(results) and item.get(field):
+ item['sentiment'] = results[i]
+ return data
+ else:
+ return results[0] if results else None
+
+ elif model_type == 'topic':
+ # 主题分类
+ texts = []
+ if isinstance(data, list):
+ field = config.get('text_field', 'content')
+ texts = [item.get(field, '') for item in data if item.get(field)]
+ elif isinstance(data, str):
+ texts = [data]
+
+ # 获取合适的模型
+ model = self.model_router.select_model_for_text(texts[0] if texts else "", "topic")
+
+ # 执行分析
+ results = []
+ for text in texts:
+ result = analyzer.analyze_topic(text, model=model)
+ results.append(result)
+
+ # 如果输入是列表,将结果合并回原始数据
+ if isinstance(data, list):
+ field = config.get('text_field', 'content')
+ for i, item in enumerate(data):
+ if i < len(results) and item.get(field):
+ item['topic'] = results[i]
+ return data
+ else:
+ return results[0] if results else None
+
+ elif model_type == 'keywords':
+ # 关键词提取
+ texts = []
+ if isinstance(data, list):
+ field = config.get('text_field', 'content')
+ texts = [item.get(field, '') for item in data if item.get(field)]
+ elif isinstance(data, str):
+ texts = [data]
+
+ # 获取合适的模型
+ model = self.model_router.select_model_for_text(texts[0] if texts else "", "keyword")
+
+ # 执行分析
+ results = []
+ for text in texts:
+ result = analyzer.extract_keywords(text, model=model)
+ results.append(result)
+
+ # 如果输入是列表,将结果合并回原始数据
+ if isinstance(data, list):
+ field = config.get('text_field', 'content')
+ for i, item in enumerate(data):
+ if i < len(results) and item.get(field):
+ item['keywords'] = results[i]
+ return data
+ else:
+ return results[0] if results else None
+
+ elif model_type == 'summarize':
+ # 文本摘要
+ texts = []
+ if isinstance(data, list):
+ field = config.get('text_field', 'content')
+ texts = [item.get(field, '') for item in data if item.get(field)]
+ elif isinstance(data, str):
+ texts = [data]
+
+ # 获取合适的模型
+ model = self.model_router.select_model_for_text(texts[0] if texts else "", "summarization")
+
+ # 执行分析
+ results = []
+ for text in texts:
+ result = analyzer.summarize_text(text, model=model)
+ results.append(result)
+
+ # 如果输入是列表,将结果合并回原始数据
+ if isinstance(data, list):
+ field = config.get('text_field', 'content')
+ for i, item in enumerate(data):
+ if i < len(results) and item.get(field):
+ item['summary'] = results[i]
+ return data
+ else:
+ return results[0] if results else None
+
+ else:
+ logger.warning(f"未知的模型类型: {model_type}")
+ return data
+
+ def _execute_visualization(self, config, input_data):
+ """执行可视化组件"""
+ visualization_type = config.get('visualization_type')
+ data = input_data.get('default', [])
+
+ if not data:
+ return {}
+
+ if visualization_type == 'chart':
+ # 图表可视化
+ chart_type = config.get('chart_type', 'bar')
+ x_field = config.get('x_field')
+ y_field = config.get('y_field')
+ title = config.get('title', '数据可视化')
+
+ if not x_field or not y_field:
+ return {'error': '缺少x或y字段'}
+
+ # 提取数据
+ chart_data = {
+ 'type': chart_type,
+ 'title': title,
+ 'xAxis': {'type': 'category', 'data': []},
+ 'yAxis': {'type': 'value'},
+ 'series': [{'data': []}]
+ }
+
+ for item in data:
+ x_value = item.get(x_field)
+ y_value = item.get(y_field)
+
+ if x_value is not None and y_value is not None:
+ chart_data['xAxis']['data'].append(x_value)
+ chart_data['series'][0]['data'].append(y_value)
+
+ return chart_data
+
+ elif visualization_type == 'table':
+ # 表格可视化
+ columns = config.get('columns', [])
+ title = config.get('title', '数据表格')
+
+ # 如果没有指定列,使用数据中的所有字段
+ if not columns and isinstance(data, list) and data:
+ columns = list(data[0].keys())
+
+ # 构建表格数据
+ table_data = {
+ 'type': 'table',
+ 'title': title,
+ 'columns': columns,
+ 'data': data
+ }
+
+ return table_data
+
+ elif visualization_type == 'wordcloud':
+ # 词云可视化
+ word_field = config.get('word_field')
+ value_field = config.get('value_field')
+ title = config.get('title', '词云图')
+
+ if not word_field:
+ return {'error': '缺少词字段'}
+
+ # 构建词云数据
+ wordcloud_data = {
+ 'type': 'wordcloud',
+ 'title': title,
+ 'data': []
+ }
+
+ for item in data:
+ word = item.get(word_field)
+ value = item.get(value_field, 1)
+
+ if word:
+ wordcloud_data['data'].append({
+ 'name': word,
+ 'value': value
+ })
+
+ return wordcloud_data
+
+ else:
+ logger.warning(f"未知的可视化类型: {visualization_type}")
+ return {}
+
+ def _get_final_outputs(self, dependency_graph, result_map):
+ """获取最终输出结果"""
+ # 找出没有后继节点的叶子节点
+ leaf_nodes = []
+ all_targets = set()
+
+ for node, deps in dependency_graph.items():
+ all_targets.update(deps)
+
+ for node in dependency_graph:
+ if node not in all_targets:
+ leaf_nodes.append(node)
+
+ # 收集所有叶子节点的结果
+ outputs = {}
+ for node in leaf_nodes:
+ if node in result_map:
+ outputs[node] = result_map[node]
+
+ return outputs
\ No newline at end of file
diff --git a/views/workflow_api.py b/views/workflow_api.py
new file mode 100644
index 0000000..ab3dd38
--- /dev/null
+++ b/views/workflow_api.py
@@ -0,0 +1,896 @@
+import os
+import json
+import time
+import uuid
+import logging
+from datetime import datetime, timedelta
+from flask import Blueprint, request, jsonify, current_app
+from utils.db_manager import DatabaseManager
+from utils.sensitive_filter import filter_dict
+from utils.cache_manager import CacheManager
+
+workflow_bp = Blueprint('workflow', __name__, url_prefix='/api/workflow')
+logger = logging.getLogger('workflow_api')
+logger.setLevel(logging.INFO)
+
+# 缓存管理器
+workflow_cache = CacheManager(name="workflows", memory_capacity=100, cache_duration=1)
+
+# 默认爬虫配置模板
+DEFAULT_CRAWLER_TEMPLATES = [
+ {
+ "id": "default_weibo",
+ "name": "微博热门话题",
+ "description": "抓取微博热门话题及相关评论",
+ "icon": "fab fa-weibo",
+ "config": {
+ "source": "weibo",
+ "crawlDepth": 2,
+ "interval": 3600,
+ "maxRetries": 3,
+ "timeout": 30,
+ "maxConcurrent": 2,
+ "userAgent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
+ "filters": {
+ "minComments": 10,
+ "minLikes": 50,
+ "excludeKeywords": []
+ }
+ }
+ },
+ {
+ "id": "weibo_trending",
+ "name": "微博热搜榜",
+ "description": "抓取微博热搜榜单内容",
+ "icon": "fas fa-fire",
+ "config": {
+ "source": "weibo_trending",
+ "crawlDepth": 1,
+ "interval": 1800,
+ "maxRetries": 3,
+ "timeout": 20,
+ "maxConcurrent": 1,
+ "filters": {
+ "topN": 50,
+ "excludeKeywords": []
+ }
+ }
+ }
+]
+
+# 默认分析流程模板
+DEFAULT_ANALYSIS_TEMPLATES = [
+ {
+ "id": "sentiment_analysis",
+ "name": "情感分析流程",
+ "description": "对文本进行情感分析",
+ "icon": "fas fa-smile",
+ "components": [
+ {
+ "id": "data_source",
+ "type": "data_source",
+ "name": "数据源",
+ "config": {
+ "source_type": "database",
+ "table": "comments",
+ "filter": {
+ "timeRange": "1d"
+ }
+ },
+ "position": {"x": 100, "y": 100}
+ },
+ {
+ "id": "text_preprocessing",
+ "type": "preprocessing",
+ "name": "文本预处理",
+ "config": {
+ "removeStopwords": True,
+ "removeURLs": True,
+ "removeEmojis": False
+ },
+ "position": {"x": 300, "y": 100}
+ },
+ {
+ "id": "sentiment_model",
+ "type": "model",
+ "name": "情感分析模型",
+ "config": {
+ "model_type": "sentiment",
+ "api": "openai",
+ "optimize_for": "balanced"
+ },
+ "position": {"x": 500, "y": 100}
+ },
+ {
+ "id": "visualization",
+ "type": "visualization",
+ "name": "可视化",
+ "config": {
+ "chart_type": "pie",
+ "title": "情感分布"
+ },
+ "position": {"x": 700, "y": 100}
+ }
+ ],
+ "connections": [
+ {"source": "data_source", "target": "text_preprocessing"},
+ {"source": "text_preprocessing", "target": "sentiment_model"},
+ {"source": "sentiment_model", "target": "visualization"}
+ ]
+ },
+ {
+ "id": "topic_analysis",
+ "name": "话题分析流程",
+ "description": "对文本进行话题分类和关键词提取",
+ "icon": "fas fa-tasks",
+ "components": [
+ {
+ "id": "data_source",
+ "type": "data_source",
+ "name": "数据源",
+ "config": {
+ "source_type": "database",
+ "table": "weibo_posts",
+ "filter": {
+ "timeRange": "7d"
+ }
+ },
+ "position": {"x": 100, "y": 100}
+ },
+ {
+ "id": "text_preprocessing",
+ "type": "preprocessing",
+ "name": "文本预处理",
+ "config": {
+ "removeStopwords": True,
+ "removeURLs": True,
+ "removeEmojis": True
+ },
+ "position": {"x": 300, "y": 100}
+ },
+ {
+ "id": "topic_model",
+ "type": "model",
+ "name": "话题分类模型",
+ "config": {
+ "model_type": "topic_classification",
+ "api": "deepseek",
+ "optimize_for": "performance"
+ },
+ "position": {"x": 500, "y": 50}
+ },
+ {
+ "id": "keyword_model",
+ "type": "model",
+ "name": "关键词提取模型",
+ "config": {
+ "model_type": "keyword_extraction",
+ "api": "openai",
+ "optimize_for": "balanced"
+ },
+ "position": {"x": 500, "y": 150}
+ },
+ {
+ "id": "topic_viz",
+ "type": "visualization",
+ "name": "话题分布",
+ "config": {
+ "chart_type": "bar",
+ "title": "话题分布"
+ },
+ "position": {"x": 700, "y": 50}
+ },
+ {
+ "id": "keyword_viz",
+ "type": "visualization",
+ "name": "关键词云",
+ "config": {
+ "chart_type": "wordcloud",
+ "title": "热门关键词"
+ },
+ "position": {"x": 700, "y": 150}
+ }
+ ],
+ "connections": [
+ {"source": "data_source", "target": "text_preprocessing"},
+ {"source": "text_preprocessing", "target": "topic_model"},
+ {"source": "text_preprocessing", "target": "keyword_model"},
+ {"source": "topic_model", "target": "topic_viz"},
+ {"source": "keyword_model", "target": "keyword_viz"}
+ ]
+ }
+]
+
+# 默认可用组件
+AVAILABLE_COMPONENTS = {
+ "data_source": [
+ {
+ "id": "database",
+ "name": "数据库",
+ "description": "从系统数据库获取数据",
+ "config_schema": {
+ "table": {"type": "string", "description": "数据表名", "required": True},
+ "filter": {"type": "object", "description": "数据过滤条件"}
+ }
+ },
+ {
+ "id": "api",
+ "name": "API接口",
+ "description": "从外部API获取数据",
+ "config_schema": {
+ "url": {"type": "string", "description": "API URL", "required": True},
+ "method": {"type": "string", "description": "请求方法", "default": "GET"},
+ "headers": {"type": "object", "description": "请求头"},
+ "params": {"type": "object", "description": "请求参数"}
+ }
+ },
+ {
+ "id": "csv",
+ "name": "CSV文件",
+ "description": "从CSV文件导入数据",
+ "config_schema": {
+ "file_path": {"type": "string", "description": "文件路径", "required": True},
+ "encoding": {"type": "string", "description": "文件编码", "default": "utf-8"},
+ "delimiter": {"type": "string", "description": "分隔符", "default": ","}
+ }
+ }
+ ],
+ "preprocessing": [
+ {
+ "id": "text_preprocessing",
+ "name": "文本预处理",
+ "description": "清洗和规范化文本数据",
+ "config_schema": {
+ "removeStopwords": {"type": "boolean", "description": "去除停用词", "default": True},
+ "removeURLs": {"type": "boolean", "description": "去除URL", "default": True},
+ "removeEmojis": {"type": "boolean", "description": "去除表情符号", "default": False},
+ "lowercase": {"type": "boolean", "description": "转为小写", "default": True}
+ }
+ },
+ {
+ "id": "tokenization",
+ "name": "分词",
+ "description": "将文本切分为词语或标记",
+ "config_schema": {
+ "method": {"type": "string", "description": "分词方法", "default": "jieba"},
+ "pos_tagging": {"type": "boolean", "description": "进行词性标注", "default": False}
+ }
+ },
+ {
+ "id": "feature_extraction",
+ "name": "特征提取",
+ "description": "从文本提取数值特征",
+ "config_schema": {
+ "method": {"type": "string", "description": "特征提取方法", "default": "tfidf"},
+ "max_features": {"type": "integer", "description": "最大特征数", "default": 1000}
+ }
+ }
+ ],
+ "model": [
+ {
+ "id": "sentiment",
+ "name": "情感分析",
+ "description": "分析文本情感倾向",
+ "config_schema": {
+ "api": {"type": "string", "description": "使用的API", "default": "openai"},
+ "model_type": {"type": "string", "description": "模型类型", "default": "sentiment_analysis"},
+ "optimize_for": {"type": "string", "description": "优化目标", "default": "balanced"}
+ }
+ },
+ {
+ "id": "topic_classification",
+ "name": "话题分类",
+ "description": "对文本进行话题分类",
+ "config_schema": {
+ "api": {"type": "string", "description": "使用的API", "default": "deepseek"},
+ "model_type": {"type": "string", "description": "模型类型", "default": "topic_classification"},
+ "optimize_for": {"type": "string", "description": "优化目标", "default": "performance"}
+ }
+ },
+ {
+ "id": "keyword_extraction",
+ "name": "关键词提取",
+ "description": "从文本中提取关键词",
+ "config_schema": {
+ "api": {"type": "string", "description": "使用的API", "default": "openai"},
+ "model_type": {"type": "string", "description": "模型类型", "default": "keyword_extraction"},
+ "optimize_for": {"type": "string", "description": "优化目标", "default": "balanced"}
+ }
+ },
+ {
+ "id": "custom_ai",
+ "name": "自定义AI模型",
+ "description": "使用自定义AI模型进行分析",
+ "config_schema": {
+ "model_path": {"type": "string", "description": "模型路径", "required": True},
+ "model_type": {"type": "string", "description": "模型类型", "required": True}
+ }
+ }
+ ],
+ "visualization": [
+ {
+ "id": "line_chart",
+ "name": "折线图",
+ "description": "展示数据随时间的变化趋势",
+ "config_schema": {
+ "title": {"type": "string", "description": "图表标题", "default": "时间趋势"},
+ "x_axis": {"type": "string", "description": "X轴字段", "default": "time"},
+ "y_axis": {"type": "string", "description": "Y轴字段", "default": "value"},
+ "color": {"type": "string", "description": "线条颜色", "default": "#1890ff"}
+ }
+ },
+ {
+ "id": "bar_chart",
+ "name": "柱状图",
+ "description": "展示不同类别的数据对比",
+ "config_schema": {
+ "title": {"type": "string", "description": "图表标题", "default": "数据对比"},
+ "x_axis": {"type": "string", "description": "X轴字段", "default": "category"},
+ "y_axis": {"type": "string", "description": "Y轴字段", "default": "value"}
+ }
+ },
+ {
+ "id": "pie_chart",
+ "name": "饼图",
+ "description": "展示数据的构成比例",
+ "config_schema": {
+ "title": {"type": "string", "description": "图表标题", "default": "比例分布"},
+ "value_field": {"type": "string", "description": "值字段", "default": "value"},
+ "label_field": {"type": "string", "description": "标签字段", "default": "label"}
+ }
+ },
+ {
+ "id": "wordcloud",
+ "name": "词云图",
+ "description": "直观展示文本中的高频词",
+ "config_schema": {
+ "title": {"type": "string", "description": "图表标题", "default": "关键词云"},
+ "max_words": {"type": "integer", "description": "最大词数", "default": 100},
+ "color_scheme": {"type": "string", "description": "配色方案", "default": "viridis"}
+ }
+ },
+ {
+ "id": "heatmap",
+ "name": "热力图",
+ "description": "展示数据的密度分布",
+ "config_schema": {
+ "title": {"type": "string", "description": "图表标题", "default": "热力分布"},
+ "x_axis": {"type": "string", "description": "X轴字段", "default": "x"},
+ "y_axis": {"type": "string", "description": "Y轴字段", "default": "y"},
+ "value_field": {"type": "string", "description": "值字段", "default": "value"}
+ }
+ }
+ ]
+}
+
+@workflow_bp.route('/crawler-templates', methods=['GET'])
+def get_crawler_templates():
+ """获取爬虫配置模板列表"""
+ # 从缓存获取
+ templates = workflow_cache.get('crawler_templates')
+ if templates is None:
+ # 从数据库获取用户定义的模板
+ db = DatabaseManager.get_connection()
+ cursor = db.cursor()
+ cursor.execute("""
+ SELECT id, name, description, icon, config
+ FROM crawler_templates
+ WHERE deleted = 0
+ ORDER BY created_at DESC
+ """)
+ user_templates = cursor.fetchall()
+ cursor.close()
+
+ # 结合默认模板
+ templates = DEFAULT_CRAWLER_TEMPLATES + list(user_templates)
+
+ # 缓存结果
+ workflow_cache.set('crawler_templates', templates)
+
+ return jsonify({
+ 'success': True,
+ 'data': filter_dict(templates)
+ })
+
+@workflow_bp.route('/crawler-templates/', methods=['GET'])
+def get_crawler_template(template_id):
+ """获取指定爬虫配置模板"""
+ # 查找默认模板
+ for template in DEFAULT_CRAWLER_TEMPLATES:
+ if template['id'] == template_id:
+ return jsonify({
+ 'success': True,
+ 'data': filter_dict(template)
+ })
+
+ # 从数据库查找用户模板
+ db = DatabaseManager.get_connection()
+ cursor = db.cursor()
+ cursor.execute("""
+ SELECT id, name, description, icon, config
+ FROM crawler_templates
+ WHERE id = %s AND deleted = 0
+ """, (template_id,))
+ template = cursor.fetchone()
+ cursor.close()
+
+ if not template:
+ return jsonify({
+ 'success': False,
+ 'message': f"未找到模板: {template_id}"
+ }), 404
+
+ return jsonify({
+ 'success': True,
+ 'data': filter_dict(template)
+ })
+
+@workflow_bp.route('/crawler-templates', methods=['POST'])
+def create_crawler_template():
+ """创建爬虫配置模板"""
+ data = request.json
+ required_fields = ['name', 'description', 'config']
+
+ # 验证必要字段
+ for field in required_fields:
+ if field not in data:
+ return jsonify({
+ 'success': False,
+ 'message': f"缺少必要字段: {field}"
+ }), 400
+
+ # 生成ID
+ template_id = f"template_{int(time.time())}_{uuid.uuid4().hex[:8]}"
+
+ # 准备数据
+ template = {
+ 'id': template_id,
+ 'name': data['name'],
+ 'description': data['description'],
+ 'icon': data.get('icon', 'fas fa-spider'),
+ 'config': data['config'],
+ 'created_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+ 'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+ 'deleted': 0
+ }
+
+ # 保存到数据库
+ db = DatabaseManager.get_connection()
+ cursor = db.cursor()
+ try:
+ cursor.execute("""
+ INSERT INTO crawler_templates
+ (id, name, description, icon, config, created_at, updated_at, deleted)
+ VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
+ """, (
+ template['id'],
+ template['name'],
+ template['description'],
+ template['icon'],
+ json.dumps(template['config']),
+ template['created_at'],
+ template['updated_at'],
+ template['deleted']
+ ))
+ db.commit()
+
+ # 清除缓存
+ workflow_cache.invalidate('crawler_templates')
+
+ return jsonify({
+ 'success': True,
+ 'data': filter_dict(template)
+ }), 201
+ except Exception as e:
+ db.rollback()
+ logger.error(f"创建爬虫模板失败: {e}")
+ return jsonify({
+ 'success': False,
+ 'message': f"创建模板失败: {str(e)}"
+ }), 500
+ finally:
+ cursor.close()
+
+@workflow_bp.route('/crawler-templates/', methods=['PUT'])
+def update_crawler_template(template_id):
+ """更新爬虫配置模板"""
+ data = request.json
+
+ # 验证模板是否存在
+ db = DatabaseManager.get_connection()
+ cursor = db.cursor()
+ cursor.execute("""
+ SELECT id FROM crawler_templates
+ WHERE id = %s AND deleted = 0
+ """, (template_id,))
+ exists = cursor.fetchone()
+
+ if not exists:
+ cursor.close()
+ return jsonify({
+ 'success': False,
+ 'message': f"未找到模板: {template_id}"
+ }), 404
+
+ # 准备更新数据
+ update_data = {}
+ if 'name' in data:
+ update_data['name'] = data['name']
+ if 'description' in data:
+ update_data['description'] = data['description']
+ if 'icon' in data:
+ update_data['icon'] = data['icon']
+ if 'config' in data:
+ update_data['config'] = json.dumps(data['config'])
+
+ update_data['updated_at'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+
+ # 构建SQL语句
+ sql = "UPDATE crawler_templates SET "
+ sql += ", ".join([f"{key} = %s" for key in update_data.keys()])
+ sql += " WHERE id = %s"
+
+ # 执行更新
+ try:
+ cursor.execute(sql, list(update_data.values()) + [template_id])
+ db.commit()
+
+ # 清除缓存
+ workflow_cache.invalidate('crawler_templates')
+
+ return jsonify({
+ 'success': True,
+ 'message': "模板更新成功"
+ })
+ except Exception as e:
+ db.rollback()
+ logger.error(f"更新爬虫模板失败: {e}")
+ return jsonify({
+ 'success': False,
+ 'message': f"更新模板失败: {str(e)}"
+ }), 500
+ finally:
+ cursor.close()
+
+@workflow_bp.route('/crawler-templates/', methods=['DELETE'])
+def delete_crawler_template(template_id):
+ """删除爬虫配置模板"""
+ # 验证模板是否存在
+ db = DatabaseManager.get_connection()
+ cursor = db.cursor()
+ cursor.execute("""
+ SELECT id FROM crawler_templates
+ WHERE id = %s AND deleted = 0
+ """, (template_id,))
+ exists = cursor.fetchone()
+
+ if not exists:
+ cursor.close()
+ return jsonify({
+ 'success': False,
+ 'message': f"未找到模板: {template_id}"
+ }), 404
+
+ # 软删除
+ try:
+ cursor.execute("""
+ UPDATE crawler_templates
+ SET deleted = 1, updated_at = %s
+ WHERE id = %s
+ """, (datetime.now().strftime('%Y-%m-%d %H:%M:%S'), template_id))
+ db.commit()
+
+ # 清除缓存
+ workflow_cache.invalidate('crawler_templates')
+
+ return jsonify({
+ 'success': True,
+ 'message': "模板删除成功"
+ })
+ except Exception as e:
+ db.rollback()
+ logger.error(f"删除爬虫模板失败: {e}")
+ return jsonify({
+ 'success': False,
+ 'message': f"删除模板失败: {str(e)}"
+ }), 500
+ finally:
+ cursor.close()
+
+@workflow_bp.route('/analysis-templates', methods=['GET'])
+def get_analysis_templates():
+ """获取分析流程模板列表"""
+ # 从缓存获取
+ templates = workflow_cache.get('analysis_templates')
+ if templates is None:
+ # 从数据库获取用户定义的模板
+ db = DatabaseManager.get_connection()
+ cursor = db.cursor()
+ cursor.execute("""
+ SELECT id, name, description, icon, components, connections
+ FROM analysis_templates
+ WHERE deleted = 0
+ ORDER BY created_at DESC
+ """)
+ user_templates = cursor.fetchall()
+ cursor.close()
+
+ # 结合默认模板
+ templates = DEFAULT_ANALYSIS_TEMPLATES + list(user_templates)
+
+ # 缓存结果
+ workflow_cache.set('analysis_templates', templates)
+
+ return jsonify({
+ 'success': True,
+ 'data': filter_dict(templates)
+ })
+
+@workflow_bp.route('/analysis-templates/', methods=['GET'])
+def get_analysis_template(template_id):
+ """获取指定分析流程模板"""
+ # 查找默认模板
+ for template in DEFAULT_ANALYSIS_TEMPLATES:
+ if template['id'] == template_id:
+ return jsonify({
+ 'success': True,
+ 'data': filter_dict(template)
+ })
+
+ # 从数据库查找用户模板
+ db = DatabaseManager.get_connection()
+ cursor = db.cursor()
+ cursor.execute("""
+ SELECT id, name, description, icon, components, connections
+ FROM analysis_templates
+ WHERE id = %s AND deleted = 0
+ """, (template_id,))
+ template = cursor.fetchone()
+ cursor.close()
+
+ if not template:
+ return jsonify({
+ 'success': False,
+ 'message': f"未找到模板: {template_id}"
+ }), 404
+
+ return jsonify({
+ 'success': True,
+ 'data': filter_dict(template)
+ })
+
+@workflow_bp.route('/analysis-templates', methods=['POST'])
+def create_analysis_template():
+ """创建分析流程模板"""
+ data = request.json
+ required_fields = ['name', 'description', 'components', 'connections']
+
+ # 验证必要字段
+ for field in required_fields:
+ if field not in data:
+ return jsonify({
+ 'success': False,
+ 'message': f"缺少必要字段: {field}"
+ }), 400
+
+ # 生成ID
+ template_id = f"template_{int(time.time())}_{uuid.uuid4().hex[:8]}"
+
+ # 准备数据
+ template = {
+ 'id': template_id,
+ 'name': data['name'],
+ 'description': data['description'],
+ 'icon': data.get('icon', 'fas fa-chart-line'),
+ 'components': data['components'],
+ 'connections': data['connections'],
+ 'created_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+ 'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+ 'deleted': 0
+ }
+
+ # 保存到数据库
+ db = DatabaseManager.get_connection()
+ cursor = db.cursor()
+ try:
+ cursor.execute("""
+ INSERT INTO analysis_templates
+ (id, name, description, icon, components, connections, created_at, updated_at, deleted)
+ VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
+ """, (
+ template['id'],
+ template['name'],
+ template['description'],
+ template['icon'],
+ json.dumps(template['components']),
+ json.dumps(template['connections']),
+ template['created_at'],
+ template['updated_at'],
+ template['deleted']
+ ))
+ db.commit()
+
+ # 清除缓存
+ workflow_cache.invalidate('analysis_templates')
+
+ return jsonify({
+ 'success': True,
+ 'data': filter_dict(template)
+ }), 201
+ except Exception as e:
+ db.rollback()
+ logger.error(f"创建分析模板失败: {e}")
+ return jsonify({
+ 'success': False,
+ 'message': f"创建模板失败: {str(e)}"
+ }), 500
+ finally:
+ cursor.close()
+
+@workflow_bp.route('/analysis-templates/', methods=['PUT'])
+def update_analysis_template(template_id):
+ """更新分析流程模板"""
+ data = request.json
+
+ # 验证模板是否存在
+ db = DatabaseManager.get_connection()
+ cursor = db.cursor()
+ cursor.execute("""
+ SELECT id FROM analysis_templates
+ WHERE id = %s AND deleted = 0
+ """, (template_id,))
+ exists = cursor.fetchone()
+
+ if not exists:
+ cursor.close()
+ return jsonify({
+ 'success': False,
+ 'message': f"未找到模板: {template_id}"
+ }), 404
+
+ # 准备更新数据
+ update_data = {}
+ if 'name' in data:
+ update_data['name'] = data['name']
+ if 'description' in data:
+ update_data['description'] = data['description']
+ if 'icon' in data:
+ update_data['icon'] = data['icon']
+ if 'components' in data:
+ update_data['components'] = json.dumps(data['components'])
+ if 'connections' in data:
+ update_data['connections'] = json.dumps(data['connections'])
+
+ update_data['updated_at'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+
+ # 构建SQL语句
+ sql = "UPDATE analysis_templates SET "
+ sql += ", ".join([f"{key} = %s" for key in update_data.keys()])
+ sql += " WHERE id = %s"
+
+ # 执行更新
+ try:
+ cursor.execute(sql, list(update_data.values()) + [template_id])
+ db.commit()
+
+ # 清除缓存
+ workflow_cache.invalidate('analysis_templates')
+
+ return jsonify({
+ 'success': True,
+ 'message': "模板更新成功"
+ })
+ except Exception as e:
+ db.rollback()
+ logger.error(f"更新分析模板失败: {e}")
+ return jsonify({
+ 'success': False,
+ 'message': f"更新模板失败: {str(e)}"
+ }), 500
+ finally:
+ cursor.close()
+
+@workflow_bp.route('/analysis-templates/', methods=['DELETE'])
+def delete_analysis_template(template_id):
+ """删除分析流程模板"""
+ # 验证模板是否存在
+ db = DatabaseManager.get_connection()
+ cursor = db.cursor()
+ cursor.execute("""
+ SELECT id FROM analysis_templates
+ WHERE id = %s AND deleted = 0
+ """, (template_id,))
+ exists = cursor.fetchone()
+
+ if not exists:
+ cursor.close()
+ return jsonify({
+ 'success': False,
+ 'message': f"未找到模板: {template_id}"
+ }), 404
+
+ # 软删除
+ try:
+ cursor.execute("""
+ UPDATE analysis_templates
+ SET deleted = 1, updated_at = %s
+ WHERE id = %s
+ """, (datetime.now().strftime('%Y-%m-%d %H:%M:%S'), template_id))
+ db.commit()
+
+ # 清除缓存
+ workflow_cache.invalidate('analysis_templates')
+
+ return jsonify({
+ 'success': True,
+ 'message': "模板删除成功"
+ })
+ except Exception as e:
+ db.rollback()
+ logger.error(f"删除分析模板失败: {e}")
+ return jsonify({
+ 'success': False,
+ 'message': f"删除模板失败: {str(e)}"
+ }), 500
+ finally:
+ cursor.close()
+
+@workflow_bp.route('/components', methods=['GET'])
+def get_available_components():
+ """获取可用组件列表"""
+ return jsonify({
+ 'success': True,
+ 'data': filter_dict(AVAILABLE_COMPONENTS)
+ })
+
+@workflow_bp.route('/run-workflow', methods=['POST'])
+def run_workflow():
+ """执行工作流"""
+ data = request.json
+
+ # 验证必要字段
+ if 'components' not in data or 'connections' not in data:
+ return jsonify({
+ 'success': False,
+ 'message': "缺少必要字段: components 或 connections"
+ }), 400
+
+ # 这里是执行工作流逻辑的占位符
+ # 实际实现需要根据组件类型和连接关系建立执行计划并执行
+
+ # 记录执行请求
+ logger.info(f"收到工作流执行请求,组件数量: {len(data['components'])}, 连接数量: {len(data['connections'])}")
+
+ # 创建任务ID
+ task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
+
+ # 返回任务ID
+ return jsonify({
+ 'success': True,
+ 'message': "工作流执行请求已提交",
+ 'data': {
+ 'task_id': task_id,
+ 'status': 'pending'
+ }
+ })
+
+@workflow_bp.route('/task-status/', methods=['GET'])
+def get_task_status(task_id):
+ """获取任务执行状态"""
+ # 这里是获取任务状态的占位符
+ # 实际实现需要查询任务执行状态
+
+ # 示例状态
+ status = {
+ 'task_id': task_id,
+ 'status': 'running',
+ 'progress': 45,
+ 'message': "正在执行数据预处理",
+ 'started_at': (datetime.now() - timedelta(minutes=2)).strftime('%Y-%m-%d %H:%M:%S'),
+ 'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+ }
+
+ return jsonify({
+ 'success': True,
+ 'data': status
+ })
\ No newline at end of file