🚀 Major Upgrade! Visual Workflow Orchestrator and AI-Powered Crawler Implemented. Added Model Arena Feature and Efficiency Optimizations (Two-Level Caching Architecture + End-to-End Performance Enhancements).
This commit is contained in:
@@ -83,9 +83,11 @@ app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=2)
|
||||
from views.page import page
|
||||
from views.user import user
|
||||
from views.spider_control import spider_bp
|
||||
from views.workflow_api import workflow_bp
|
||||
app.register_blueprint(page.pb)
|
||||
app.register_blueprint(user.ub)
|
||||
app.register_blueprint(spider_bp)
|
||||
app.register_blueprint(workflow_bp) # 注册工作流蓝图
|
||||
|
||||
# 首页路由
|
||||
@app.route('/')
|
||||
|
||||
+49
-1
@@ -46,4 +46,52 @@ CREATE TABLE `user` (
|
||||
`id` int(11) NOT NULL AUTO_INCREMENT,
|
||||
`createTime` varchar(255) DEFAULT NULL,
|
||||
PRIMARY KEY (`id`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=4 DEFAULT CHARSET=utf8;
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=4 DEFAULT CHARSET=utf8;
|
||||
|
||||
-- 爬虫模板表
|
||||
CREATE TABLE IF NOT EXISTS `crawler_templates` (
|
||||
`id` VARCHAR(64) NOT NULL COMMENT '模板ID',
|
||||
`name` VARCHAR(64) NOT NULL COMMENT '模板名称',
|
||||
`description` VARCHAR(255) NULL COMMENT '模板描述',
|
||||
`icon` VARCHAR(32) NULL COMMENT '图标',
|
||||
`config` JSON NOT NULL COMMENT '配置JSON',
|
||||
`created_at` DATETIME NOT NULL COMMENT '创建时间',
|
||||
`updated_at` DATETIME NOT NULL COMMENT '更新时间',
|
||||
`deleted` TINYINT(1) NOT NULL DEFAULT 0 COMMENT '是否删除',
|
||||
PRIMARY KEY (`id`),
|
||||
INDEX `idx_crawler_templates_name` (`name`)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='爬虫配置模板表';
|
||||
|
||||
-- 分析流程模板表
|
||||
CREATE TABLE IF NOT EXISTS `analysis_templates` (
|
||||
`id` VARCHAR(64) NOT NULL COMMENT '模板ID',
|
||||
`name` VARCHAR(64) NOT NULL COMMENT '模板名称',
|
||||
`description` VARCHAR(255) NULL COMMENT '模板描述',
|
||||
`icon` VARCHAR(32) NULL COMMENT '图标',
|
||||
`components` JSON NOT NULL COMMENT '组件JSON',
|
||||
`connections` JSON NOT NULL COMMENT '连接JSON',
|
||||
`created_at` DATETIME NOT NULL COMMENT '创建时间',
|
||||
`updated_at` DATETIME NOT NULL COMMENT '更新时间',
|
||||
`deleted` TINYINT(1) NOT NULL DEFAULT 0 COMMENT '是否删除',
|
||||
PRIMARY KEY (`id`),
|
||||
INDEX `idx_analysis_templates_name` (`name`)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='分析流程模板表';
|
||||
|
||||
-- 工作流执行任务表
|
||||
CREATE TABLE IF NOT EXISTS `workflow_tasks` (
|
||||
`id` VARCHAR(64) NOT NULL COMMENT '任务ID',
|
||||
`template_id` VARCHAR(64) NULL COMMENT '关联模板ID',
|
||||
`type` VARCHAR(32) NOT NULL COMMENT '任务类型:crawler/analysis',
|
||||
`status` VARCHAR(16) NOT NULL COMMENT '任务状态:pending/running/completed/failed',
|
||||
`progress` INT(11) NOT NULL DEFAULT 0 COMMENT '进度百分比',
|
||||
`config` JSON NOT NULL COMMENT '任务配置',
|
||||
`result` JSON NULL COMMENT '执行结果',
|
||||
`error` TEXT NULL COMMENT '错误信息',
|
||||
`started_at` DATETIME NULL COMMENT '开始时间',
|
||||
`completed_at` DATETIME NULL COMMENT '完成时间',
|
||||
`created_at` DATETIME NOT NULL COMMENT '创建时间',
|
||||
`updated_at` DATETIME NOT NULL COMMENT '更新时间',
|
||||
PRIMARY KEY (`id`),
|
||||
INDEX `idx_workflow_tasks_type_status` (`type`, `status`),
|
||||
INDEX `idx_workflow_tasks_template` (`template_id`)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='工作流执行任务表';
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,519 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>工作流编辑器 - 微博舆情分析系统</title>
|
||||
<link href="https://cdn.bootcdn.net/ajax/libs/twitter-bootstrap/5.2.3/css/bootstrap.min.css" rel="stylesheet">
|
||||
<link href="https://cdn.bootcdn.net/ajax/libs/font-awesome/6.2.0/css/all.min.css" rel="stylesheet">
|
||||
<link href="https://cdn.jsdelivr.net/npm/jsoneditor@9.5.0/dist/jsoneditor.min.css" rel="stylesheet">
|
||||
<style>
|
||||
:root {
|
||||
--primary-color: #1890ff;
|
||||
--success-color: #52c41a;
|
||||
--warning-color: #faad14;
|
||||
--error-color: #f5222d;
|
||||
--bg-color: #f0f2f5;
|
||||
--component-bg: #fafafa;
|
||||
--border-color: #d9d9d9;
|
||||
}
|
||||
|
||||
body {
|
||||
background-color: var(--bg-color);
|
||||
font-family: "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei", "Helvetica Neue", Helvetica, Arial, sans-serif;
|
||||
}
|
||||
|
||||
.navbar-brand {
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.sidebar {
|
||||
position: fixed;
|
||||
top: 56px;
|
||||
bottom: 0;
|
||||
left: 0;
|
||||
z-index: 100;
|
||||
padding: 20px 0;
|
||||
width: 280px;
|
||||
overflow-x: hidden;
|
||||
overflow-y: auto;
|
||||
background-color: white;
|
||||
border-right: 1px solid var(--border-color);
|
||||
}
|
||||
|
||||
.main-content {
|
||||
margin-left: 280px;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.workflow-canvas {
|
||||
background-color: white;
|
||||
min-height: calc(100vh - 150px);
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
|
||||
position: relative;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.component-container {
|
||||
padding: 15px;
|
||||
border-radius: 6px;
|
||||
border: 1px solid var(--border-color);
|
||||
margin-bottom: 15px;
|
||||
background-color: var(--component-bg);
|
||||
}
|
||||
|
||||
.component-item {
|
||||
background-color: white;
|
||||
padding: 8px 12px;
|
||||
border-radius: 4px;
|
||||
margin: 8px 0;
|
||||
border: 1px solid var(--border-color);
|
||||
cursor: move;
|
||||
user-select: none;
|
||||
transition: all 0.3s;
|
||||
}
|
||||
|
||||
.component-item:hover {
|
||||
border-color: var(--primary-color);
|
||||
box-shadow: 0 2px 5px rgba(24, 144, 255, 0.15);
|
||||
}
|
||||
|
||||
.workflow-node {
|
||||
position: absolute;
|
||||
width: 200px;
|
||||
min-height: 100px;
|
||||
background-color: white;
|
||||
border-radius: 6px;
|
||||
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
|
||||
padding: 12px;
|
||||
cursor: move;
|
||||
z-index: 10;
|
||||
}
|
||||
|
||||
.workflow-node .node-header {
|
||||
border-bottom: 1px solid #eee;
|
||||
padding-bottom: 8px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.workflow-node .node-title {
|
||||
font-weight: 600;
|
||||
font-size: 14px;
|
||||
display: block;
|
||||
text-overflow: ellipsis;
|
||||
overflow: hidden;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.workflow-node .node-type {
|
||||
font-size: 12px;
|
||||
color: #666;
|
||||
}
|
||||
|
||||
.workflow-node .node-ports {
|
||||
position: relative;
|
||||
height: 20px;
|
||||
margin-top: 15px;
|
||||
}
|
||||
|
||||
.port {
|
||||
position: absolute;
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
border-radius: 50%;
|
||||
background-color: var(--primary-color);
|
||||
cursor: crosshair;
|
||||
}
|
||||
|
||||
.port-in {
|
||||
top: 4px;
|
||||
left: -6px;
|
||||
}
|
||||
|
||||
.port-out {
|
||||
top: 4px;
|
||||
right: -6px;
|
||||
}
|
||||
|
||||
.connection-path {
|
||||
stroke: var(--primary-color);
|
||||
stroke-width: 2px;
|
||||
fill: none;
|
||||
}
|
||||
|
||||
.template-item {
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: 6px;
|
||||
padding: 15px;
|
||||
margin-bottom: 15px;
|
||||
cursor: pointer;
|
||||
transition: all 0.3s;
|
||||
}
|
||||
|
||||
.template-item:hover {
|
||||
border-color: var(--primary-color);
|
||||
box-shadow: 0 2px 8px rgba(24, 144, 255, 0.15);
|
||||
}
|
||||
|
||||
.template-item .template-title {
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
}
|
||||
|
||||
.template-item .template-desc {
|
||||
color: #666;
|
||||
font-size: 13px;
|
||||
margin-top: 5px;
|
||||
}
|
||||
|
||||
.active-tab {
|
||||
border-bottom: 2px solid var(--primary-color);
|
||||
color: var(--primary-color);
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.tab-content {
|
||||
padding-top: 20px;
|
||||
}
|
||||
|
||||
.task-item {
|
||||
background-color: white;
|
||||
border-radius: 6px;
|
||||
padding: 15px;
|
||||
margin-bottom: 15px;
|
||||
border-left: 4px solid var(--primary-color);
|
||||
}
|
||||
|
||||
.task-item.running {
|
||||
border-left-color: var(--primary-color);
|
||||
}
|
||||
|
||||
.task-item.completed {
|
||||
border-left-color: var(--success-color);
|
||||
}
|
||||
|
||||
.task-item.failed {
|
||||
border-left-color: var(--error-color);
|
||||
}
|
||||
|
||||
.properties-panel {
|
||||
position: fixed;
|
||||
top: 76px;
|
||||
right: 20px;
|
||||
width: 320px;
|
||||
background-color: white;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
|
||||
padding: 15px;
|
||||
max-height: calc(100vh - 120px);
|
||||
overflow-y: auto;
|
||||
z-index: 100;
|
||||
transform: translateX(360px);
|
||||
transition: transform 0.3s;
|
||||
}
|
||||
|
||||
.properties-panel.open {
|
||||
transform: translateX(0);
|
||||
}
|
||||
|
||||
.form-label {
|
||||
font-weight: 500;
|
||||
font-size: 13px;
|
||||
}
|
||||
|
||||
/* 媒体查询用于响应式设计 */
|
||||
@media (max-width: 768px) {
|
||||
.sidebar {
|
||||
width: 100%;
|
||||
position: static;
|
||||
height: auto;
|
||||
padding-bottom: 0;
|
||||
}
|
||||
|
||||
.main-content {
|
||||
margin-left: 0;
|
||||
}
|
||||
|
||||
.properties-panel {
|
||||
width: 100%;
|
||||
position: fixed;
|
||||
top: auto;
|
||||
bottom: 0;
|
||||
right: 0;
|
||||
transform: translateY(100%);
|
||||
border-radius: 8px 8px 0 0;
|
||||
max-height: 70vh;
|
||||
}
|
||||
|
||||
.properties-panel.open {
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<nav class="navbar navbar-expand-lg navbar-dark bg-dark">
|
||||
<div class="container-fluid">
|
||||
<a class="navbar-brand" href="#">
|
||||
<i class="fas fa-project-diagram me-2"></i>工作流编辑器
|
||||
</a>
|
||||
<button class="navbar-toggler" type="button" data-bs-toggle="collapse" data-bs-target="#navbarNav">
|
||||
<span class="navbar-toggler-icon"></span>
|
||||
</button>
|
||||
<div class="collapse navbar-collapse" id="navbarNav">
|
||||
<ul class="navbar-nav me-auto">
|
||||
<li class="nav-item">
|
||||
<a class="nav-link active" href="#">可视化编辑</a>
|
||||
</li>
|
||||
<li class="nav-item">
|
||||
<a class="nav-link" href="#">模板管理</a>
|
||||
</li>
|
||||
<li class="nav-item">
|
||||
<a class="nav-link" href="#">任务列表</a>
|
||||
</li>
|
||||
</ul>
|
||||
<div class="d-flex">
|
||||
<button id="saveWorkflowBtn" class="btn btn-success me-2">
|
||||
<i class="fas fa-save me-1"></i>保存
|
||||
</button>
|
||||
<button id="runWorkflowBtn" class="btn btn-primary">
|
||||
<i class="fas fa-play me-1"></i>运行
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<div class="container-fluid">
|
||||
<div class="row">
|
||||
<!-- 侧边栏 -->
|
||||
<div class="col-md-3 col-lg-2 d-md-block sidebar">
|
||||
<div class="d-flex justify-content-center mb-4">
|
||||
<div class="btn-group">
|
||||
<button class="btn btn-outline-primary active" id="componentsTabBtn">组件</button>
|
||||
<button class="btn btn-outline-primary" id="templatesTabBtn">模板</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 组件面板 -->
|
||||
<div id="componentsPanel">
|
||||
<div class="component-container">
|
||||
<h6><i class="fas fa-database me-2"></i>数据源</h6>
|
||||
<div class="component-list">
|
||||
<div class="component-item" data-type="data_source" data-subtype="database">
|
||||
<i class="fas fa-table me-2"></i>数据库
|
||||
</div>
|
||||
<div class="component-item" data-type="data_source" data-subtype="file">
|
||||
<i class="fas fa-file-alt me-2"></i>文件
|
||||
</div>
|
||||
<div class="component-item" data-type="data_source" data-subtype="crawler">
|
||||
<i class="fas fa-spider me-2"></i>爬虫
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="component-container">
|
||||
<h6><i class="fas fa-filter me-2"></i>数据处理</h6>
|
||||
<div class="component-list">
|
||||
<div class="component-item" data-type="preprocessing" data-subtype="filter">
|
||||
<i class="fas fa-filter me-2"></i>过滤
|
||||
</div>
|
||||
<div class="component-item" data-type="preprocessing" data-subtype="sort">
|
||||
<i class="fas fa-sort me-2"></i>排序
|
||||
</div>
|
||||
<div class="component-item" data-type="preprocessing" data-subtype="aggregate">
|
||||
<i class="fas fa-layer-group me-2"></i>聚合
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="component-container">
|
||||
<h6><i class="fas fa-brain me-2"></i>模型分析</h6>
|
||||
<div class="component-list">
|
||||
<div class="component-item" data-type="model" data-subtype="sentiment">
|
||||
<i class="fas fa-smile me-2"></i>情感分析
|
||||
</div>
|
||||
<div class="component-item" data-type="model" data-subtype="topic">
|
||||
<i class="fas fa-tags me-2"></i>话题分类
|
||||
</div>
|
||||
<div class="component-item" data-type="model" data-subtype="keywords">
|
||||
<i class="fas fa-key me-2"></i>关键词提取
|
||||
</div>
|
||||
<div class="component-item" data-type="model" data-subtype="summarize">
|
||||
<i class="fas fa-compress-alt me-2"></i>文本摘要
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="component-container">
|
||||
<h6><i class="fas fa-chart-bar me-2"></i>可视化</h6>
|
||||
<div class="component-list">
|
||||
<div class="component-item" data-type="visualization" data-subtype="chart">
|
||||
<i class="fas fa-chart-line me-2"></i>图表
|
||||
</div>
|
||||
<div class="component-item" data-type="visualization" data-subtype="table">
|
||||
<i class="fas fa-table me-2"></i>表格
|
||||
</div>
|
||||
<div class="component-item" data-type="visualization" data-subtype="wordcloud">
|
||||
<i class="fas fa-cloud me-2"></i>词云
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 模板面板 -->
|
||||
<div id="templatesPanel" style="display: none;">
|
||||
<div class="d-flex justify-content-between align-items-center mb-3">
|
||||
<h6 class="mb-0">爬虫模板</h6>
|
||||
<button class="btn btn-sm btn-outline-primary">
|
||||
<i class="fas fa-plus"></i> 新建
|
||||
</button>
|
||||
</div>
|
||||
<div id="crawlerTemplatesList">
|
||||
<!-- 爬虫模板列表将动态加载 -->
|
||||
</div>
|
||||
|
||||
<div class="d-flex justify-content-between align-items-center mb-3 mt-4">
|
||||
<h6 class="mb-0">分析流程模板</h6>
|
||||
<button class="btn btn-sm btn-outline-primary">
|
||||
<i class="fas fa-plus"></i> 新建
|
||||
</button>
|
||||
</div>
|
||||
<div id="analysisTemplatesList">
|
||||
<!-- 分析流程模板列表将动态加载 -->
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 主要内容 -->
|
||||
<div class="col-md-9 col-lg-10 main-content">
|
||||
<div class="workflow-canvas" id="workflowCanvas">
|
||||
<!-- 工作流节点和连接将在这里动态创建 -->
|
||||
<svg id="connectionsSvg" style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; pointer-events: none;">
|
||||
<!-- 连接线将在这里动态创建 -->
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 属性面板 -->
|
||||
<div class="properties-panel" id="propertiesPanel">
|
||||
<div class="d-flex justify-content-between align-items-center mb-3">
|
||||
<h5 class="mb-0">组件属性</h5>
|
||||
<button class="btn-close" id="closePropertiesBtn"></button>
|
||||
</div>
|
||||
<div id="propertiesContent">
|
||||
<!-- 属性内容将动态加载 -->
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 模态框 -->
|
||||
<div class="modal fade" id="saveTemplateModal" tabindex="-1">
|
||||
<div class="modal-dialog">
|
||||
<div class="modal-content">
|
||||
<div class="modal-header">
|
||||
<h5 class="modal-title">保存为模板</h5>
|
||||
<button type="button" class="btn-close" data-bs-dismiss="modal"></button>
|
||||
</div>
|
||||
<div class="modal-body">
|
||||
<form id="saveTemplateForm">
|
||||
<div class="mb-3">
|
||||
<label for="templateName" class="form-label">模板名称</label>
|
||||
<input type="text" class="form-control" id="templateName" required>
|
||||
</div>
|
||||
<div class="mb-3">
|
||||
<label for="templateDescription" class="form-label">描述</label>
|
||||
<textarea class="form-control" id="templateDescription" rows="3"></textarea>
|
||||
</div>
|
||||
<div class="mb-3">
|
||||
<label for="templateIcon" class="form-label">图标</label>
|
||||
<select class="form-select" id="templateIcon">
|
||||
<option value="chart-line">📊 图表</option>
|
||||
<option value="filter">🔍 过滤</option>
|
||||
<option value="spider">🕸️ 爬虫</option>
|
||||
<option value="brain">🧠 AI分析</option>
|
||||
<option value="database">💾 数据库</option>
|
||||
<option value="cloud">☁️ 词云</option>
|
||||
</select>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
<div class="modal-footer">
|
||||
<button type="button" class="btn btn-secondary" data-bs-dismiss="modal">取消</button>
|
||||
<button type="button" class="btn btn-primary" id="saveTemplateBtn">保存</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="modal fade" id="runWorkflowModal" tabindex="-1">
|
||||
<div class="modal-dialog">
|
||||
<div class="modal-content">
|
||||
<div class="modal-header">
|
||||
<h5 class="modal-title">运行工作流</h5>
|
||||
<button type="button" class="btn-close" data-bs-dismiss="modal"></button>
|
||||
</div>
|
||||
<div class="modal-body">
|
||||
<p>确认要运行当前工作流吗?</p>
|
||||
<div class="form-check mb-3">
|
||||
<input class="form-check-input" type="checkbox" id="saveBeforeRun" checked>
|
||||
<label class="form-check-label" for="saveBeforeRun">
|
||||
运行前保存工作流
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
<div class="modal-footer">
|
||||
<button type="button" class="btn btn-secondary" data-bs-dismiss="modal">取消</button>
|
||||
<button type="button" class="btn btn-primary" id="confirmRunBtn">运行</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="modal fade" id="taskStatusModal" tabindex="-1">
|
||||
<div class="modal-dialog modal-lg">
|
||||
<div class="modal-content">
|
||||
<div class="modal-header">
|
||||
<h5 class="modal-title">任务执行状态</h5>
|
||||
<button type="button" class="btn-close" data-bs-dismiss="modal"></button>
|
||||
</div>
|
||||
<div class="modal-body">
|
||||
<div class="mb-3">
|
||||
<h6>进度</h6>
|
||||
<div class="progress">
|
||||
<div id="taskProgressBar" class="progress-bar" role="progressbar" style="width: 0%"></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="mb-3">
|
||||
<h6>状态信息</h6>
|
||||
<div id="taskStatusInfo" class="p-3 bg-light rounded">
|
||||
<p class="mb-1">任务ID: <span id="taskIdDisplay">-</span></p>
|
||||
<p class="mb-1">状态: <span id="taskStatusDisplay">-</span></p>
|
||||
<p class="mb-1">开始时间: <span id="taskStartTimeDisplay">-</span></p>
|
||||
<p class="mb-0">完成时间: <span id="taskCompleteTimeDisplay">-</span></p>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<h6>结果预览</h6>
|
||||
<div id="taskResultPreview" class="p-3 bg-light rounded" style="max-height: 300px; overflow: auto;">
|
||||
<p class="text-muted">任务完成后将显示结果预览...</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="modal-footer">
|
||||
<button type="button" class="btn btn-secondary" data-bs-dismiss="modal">关闭</button>
|
||||
<button type="button" class="btn btn-danger" id="cancelTaskBtn">取消任务</button>
|
||||
<button type="button" class="btn btn-primary" id="viewResultBtn">查看完整结果</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script src="https://cdn.bootcdn.net/ajax/libs/jquery/3.6.1/jquery.min.js"></script>
|
||||
<script src="https://cdn.bootcdn.net/ajax/libs/twitter-bootstrap/5.2.3/js/bootstrap.bundle.min.js"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/jsoneditor@9.5.0/dist/jsoneditor.min.js"></script>
|
||||
<!-- Load the editor logic from the application's static folder; an absolute path on the
     developer's local disk cannot be resolved by browsers visiting the served page. -->
<script src="/static/js/workflow_editor.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
+262
-77
@@ -1,116 +1,301 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import shutil
|
||||
from datetime import datetime, timedelta
|
||||
import threading
|
||||
import queue
|
||||
from collections import OrderedDict
|
||||
import pickle
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
class PredictionCache:
|
||||
logger = logging.getLogger('cache_manager')
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class LRUCache:
    """Bounded in-memory mapping with least-recently-used eviction.

    Entries live in an ``OrderedDict`` ordered from least to most recently
    used. All operations take a re-entrant lock because ``CacheManager`` in
    this module touches the cache from background threads as well as from
    callers of ``get``/``set``.

    Args:
        capacity: Maximum number of entries kept. A capacity of zero or less
            disables storage: ``put`` of a new key becomes a no-op instead of
            raising ``KeyError`` from evicting out of an empty mapping.
    """

    def __init__(self, capacity):
        self.cache = OrderedDict()
        self.capacity = capacity
        self._lock = threading.RLock()

    def get(self, key):
        """Return the value stored for ``key`` and mark it most recently used.

        Returns ``None`` when the key is absent.
        """
        with self._lock:
            try:
                # Moving to the end marks the entry as most recently used.
                self.cache.move_to_end(key)
            except KeyError:
                return None
            return self.cache[key]

    def put(self, key, value):
        """Insert or update ``key``, evicting the oldest entries when full."""
        with self._lock:
            if key in self.cache:
                self.cache[key] = value
                self.cache.move_to_end(key)
                return

            if self.capacity <= 0:
                # Nothing can be stored; evicting from an empty dict would raise.
                return

            # The first item of the OrderedDict is the least recently used one.
            while len(self.cache) >= self.capacity:
                self.cache.popitem(last=False)

            self.cache[key] = value

    def remove(self, key):
        """Delete ``key`` if present; missing keys are ignored."""
        with self._lock:
            self.cache.pop(key, None)

    def clear(self):
        """Drop every entry."""
        with self._lock:
            self.cache.clear()

    def __len__(self):
        with self._lock:
            return len(self.cache)

    def get_all_keys(self):
        """Return a snapshot list of keys, least recently used first."""
        with self._lock:
            return list(self.cache)
|
||||
|
||||
|
||||
class CacheManager:
    """Two-level cache: an in-memory LRU tier backed by pickle files on disk.

    One instance exists per cache ``name``; constructing ``CacheManager`` again
    with the same name returns the existing object, and ``__init__`` is then a
    no-op. Calling the class with no arguments at all keeps the legacy
    singleton behaviour and returns the first instance ever created.

    Each entry is stored as ``{'data': ..., 'timestamp': <epoch seconds>}`` and
    is considered valid for ``cache_duration`` hours.

    NOTE(review): disk entries are read back with ``pickle.load``; the cache
    directory must only ever contain files written by this application.
    """

    _instance = None   # first instance created; served to legacy no-argument calls
    _instances = {}    # cache name -> instance
    _lock = threading.Lock()

    # Characters that would escape the cache directory or are rejected by
    # common filesystems when used in a file name.
    _UNSAFE_FILENAME_CHARS = frozenset('/\\:*?"<>|')

    def __new__(cls, *args, **kwargs):
        name = kwargs.get('name', args[0] if args else "default")
        with cls._lock:
            if not args and not kwargs and cls._instance is not None:
                # Legacy ``PredictionCache()`` style call: keep returning the
                # original shared instance.
                return cls._instance

            instance = cls._instances.get(name)
            if instance is None:
                instance = super().__new__(cls)
                cls._instances[name] = instance
                if cls._instance is None:
                    cls._instance = instance
            return instance

    def __init__(self, name="default", memory_capacity=1000, cache_duration=24,
                 disk_cache_dir="cache", flush_interval=5):
        """Set up both cache tiers and start the background threads.

        Args:
            name: Cache name; also the sub-directory used under ``disk_cache_dir``.
            memory_capacity: Maximum number of entries kept in memory.
            cache_duration: Entry lifetime in hours.
            disk_cache_dir: Root directory for on-disk cache files.
            flush_interval: Stored on the instance; the maintenance loop in
                this class currently runs on a fixed one-hour cycle and does
                not consult it.
        """
        if hasattr(self, 'initialized'):
            return

        self.name = name
        self.memory_cache = LRUCache(memory_capacity)
        self.disk_cache_dir = os.path.join(disk_cache_dir, name)
        self.cache_duration = timedelta(hours=cache_duration)
        self.flush_interval = flush_interval
        self.cache_stats = {"hits": 0, "misses": 0, "disk_hits": 0}
        self.disk_queue = queue.Queue()
        self.initialized = True

        # Make sure the cache directory exists before any thread touches it.
        os.makedirs(self.disk_cache_dir, exist_ok=True)

        # Background maintenance: expiry clean-up and memory -> disk flush.
        self.cleanup_thread = threading.Thread(target=self._cleanup_and_flush_task, daemon=True)
        self.cleanup_thread.start()

        # Background writer draining ``disk_queue``.
        self.disk_writer_thread = threading.Thread(target=self._disk_writer_task, daemon=True)
        self.disk_writer_thread.start()

        logger.info(f"初始化缓存管理器: {name},内存容量: {memory_capacity}项,缓存时间: {cache_duration}小时")

    def _get_cache_key(self, key):
        """Normalise ``key``: strings pass through, anything else is MD5-hashed."""
        if isinstance(key, str):
            return key
        return hashlib.md5(str(key).encode()).hexdigest()

    def _get_disk_path(self, key):
        """Return the cache file path for ``key`` inside ``disk_cache_dir``.

        Keys containing path separators, control characters or characters that
        are invalid in file names are replaced by their MD5 digest so the file
        always stays inside the cache directory and can actually be created.
        """
        safe_key = self._get_cache_key(key)
        if any(ch in self._UNSAFE_FILENAME_CHARS or ord(ch) < 32 for ch in safe_key):
            safe_key = hashlib.md5(safe_key.encode('utf-8')).hexdigest()
        return os.path.join(self.disk_cache_dir, f"{safe_key}.cache")

    def _is_cache_valid(self, timestamp):
        """Return True while an entry stamped ``timestamp`` has not expired."""
        cache_time = datetime.fromtimestamp(timestamp)
        return datetime.now() - cache_time < self.cache_duration

    def get(self, key):
        """Look ``key`` up in memory first, then on disk.

        Returns:
            The cached data, or ``None`` on a miss or an expired entry.
        """
        cache_key = self._get_cache_key(key)

        # 1. Memory tier.
        cache_data = self.memory_cache.get(cache_key)
        if cache_data is not None:
            if self._is_cache_valid(cache_data['timestamp']):
                self.cache_stats["hits"] += 1
                logger.debug(f"内存缓存命中: {key}")
                return cache_data['data']
            # Expired: drop it from memory and fall through to the disk tier.
            self.memory_cache.remove(cache_key)

        # 2. Disk tier.
        disk_path = self._get_disk_path(cache_key)
        if os.path.exists(disk_path):
            try:
                with open(disk_path, 'rb') as f:
                    cache_data = pickle.load(f)

                if self._is_cache_valid(cache_data['timestamp']):
                    # Promote the entry to memory for subsequent lookups.
                    self.memory_cache.put(cache_key, cache_data)
                    self.cache_stats["disk_hits"] += 1
                    logger.debug(f"磁盘缓存命中: {key}")
                    return cache_data['data']

                # Expired on disk as well: remove the file.
                os.remove(disk_path)
            except Exception as e:
                logger.warning(f"读取磁盘缓存失败: {key}, 错误: {e}")

        self.cache_stats["misses"] += 1
        logger.debug(f"缓存未命中: {key}")
        return None

    def set(self, key, data, immediate_disk_write=False):
        """Store ``data`` under ``key`` in memory and schedule the disk write.

        Args:
            key: Cache key.
            data: Value to cache.
            immediate_disk_write: Write the file synchronously instead of
                handing it to the background writer.

        Returns:
            Always ``True``.
        """
        cache_key = self._get_cache_key(key)
        cache_data = {
            'data': data,
            'timestamp': datetime.now().timestamp()
        }

        self.memory_cache.put(cache_key, cache_data)

        if immediate_disk_write:
            self._write_to_disk(cache_key, cache_data)
        else:
            self.disk_queue.put((cache_key, cache_data))

        logger.debug(f"缓存已设置: {key}")
        return True

    def invalidate(self, key):
        """Remove ``key`` from both tiers. Always returns ``True``."""
        cache_key = self._get_cache_key(key)

        self.memory_cache.remove(cache_key)

        disk_path = self._get_disk_path(cache_key)
        if os.path.exists(disk_path):
            try:
                os.remove(disk_path)
                logger.debug(f"缓存已失效: {key}")
            except Exception as e:
                logger.warning(f"删除磁盘缓存失败: {key}, 错误: {e}")

        return True

    def clear_all(self):
        """Empty both tiers and reset the statistics. Always returns ``True``."""
        self.memory_cache.clear()

        # Discard writes that are still queued; otherwise the writer thread
        # would re-create entries right after they were cleared.
        while True:
            try:
                self.disk_queue.get_nowait()
            except queue.Empty:
                break
            self.disk_queue.task_done()

        try:
            if os.path.isdir(self.disk_cache_dir):
                shutil.rmtree(self.disk_cache_dir)
            logger.info(f"所有缓存已清除: {self.name}")
        except Exception as e:
            logger.error(f"清除磁盘缓存失败: {e}")
        finally:
            # The directory must exist afterwards even when removal failed,
            # because writers and get_stats() expect it.
            os.makedirs(self.disk_cache_dir, exist_ok=True)

        self.cache_stats = {"hits": 0, "misses": 0, "disk_hits": 0}
        return True

    def get_stats(self):
        """Return counters and hit rates (percentages) for this cache.

        ``total_requests`` counts memory hits, disk hits and misses, so
        ``hit_rate`` (memory only) and ``two_level_hit_rate`` (memory + disk)
        both stay within 0-100.
        """
        memory_hits = self.cache_stats["hits"]
        disk_hits = self.cache_stats["disk_hits"]
        misses = self.cache_stats["misses"]

        total_hits = memory_hits + disk_hits
        total_requests = total_hits + misses

        memory_size = len(self.memory_cache)
        disk_size = sum(1 for f in os.listdir(self.disk_cache_dir) if f.endswith('.cache'))

        return {
            "name": self.name,
            "memory_items": memory_size,
            "disk_items": disk_size,
            "memory_hits": memory_hits,
            "disk_hits": disk_hits,
            "misses": misses,
            "total_requests": total_requests,
            "hit_rate": (memory_hits / total_requests * 100) if total_requests > 0 else 0,
            "two_level_hit_rate": (total_hits / total_requests * 100) if total_requests > 0 else 0
        }

    def _write_to_disk(self, cache_key, cache_data):
        """Pickle ``cache_data`` to the file for ``cache_key``; return success."""
        disk_path = self._get_disk_path(cache_key)
        try:
            with open(disk_path, 'wb') as f:
                pickle.dump(cache_data, f)
            return True
        except Exception as e:
            logger.warning(f"写入磁盘缓存失败: {cache_key}, 错误: {e}")
            return False

    def _disk_writer_task(self):
        """Background loop that writes queued entries to disk."""
        while True:
            try:
                try:
                    # The timeout already paces the loop while the queue is idle.
                    cache_key, cache_data = self.disk_queue.get(timeout=1)
                except queue.Empty:
                    continue
                self._write_to_disk(cache_key, cache_data)
                self.disk_queue.task_done()
            except Exception as e:
                logger.error(f"磁盘写入线程出错: {e}")
                time.sleep(5)  # back off after an unexpected error

    def _purge_expired_memory(self):
        """Drop expired entries from the memory tier."""
        for key in self.memory_cache.get_all_keys():
            cache_data = self.memory_cache.get(key)
            if cache_data is None:
                # Evicted or removed since the key snapshot was taken.
                continue
            if not self._is_cache_valid(cache_data['timestamp']):
                self.memory_cache.remove(key)

    def _purge_expired_disk(self):
        """Delete expired or unreadable cache files from the disk tier."""
        for filename in os.listdir(self.disk_cache_dir):
            if not filename.endswith('.cache'):
                continue
            filepath = os.path.join(self.disk_cache_dir, filename)
            try:
                with open(filepath, 'rb') as f:
                    cache_data = pickle.load(f)
                if not self._is_cache_valid(cache_data['timestamp']):
                    os.remove(filepath)
            except Exception as e:
                # Unreadable or corrupt file: remove it.
                logger.warning(f"读取缓存文件失败,将删除: {filepath}, 错误: {e}")
                try:
                    os.remove(filepath)
                except OSError:
                    # Already gone or locked; one bad file must not abort the pass.
                    pass

    def _flush_memory_to_disk(self):
        """Rewrite every in-memory entry to disk so both tiers agree."""
        for key in self.memory_cache.get_all_keys():
            cache_data = self.memory_cache.get(key)
            if cache_data is not None:
                self._write_to_disk(key, cache_data)

    def _cleanup_and_flush_task(self):
        """Background loop: purge expired entries, then flush memory to disk.

        Runs once per hour; after an error it also waits an hour before retrying.
        """
        while True:
            try:
                self._purge_expired_memory()
                self._purge_expired_disk()
                self._flush_memory_to_disk()
                time.sleep(3600)
            except Exception as e:
                logger.error(f"缓存清理线程出错: {e}")
                time.sleep(3600)
|
||||
|
||||
# 创建全局缓存实例
|
||||
prediction_cache = PredictionCache()
|
||||
|
||||
# 创建不同领域的缓存实例
|
||||
prediction_cache = CacheManager(name="predictions", memory_capacity=500, cache_duration=24)
|
||||
sentiment_cache = CacheManager(name="sentiment", memory_capacity=1000, cache_duration=12)
|
||||
topic_cache = CacheManager(name="topics", memory_capacity=200, cache_duration=6)
|
||||
user_data_cache = CacheManager(name="user_data", memory_capacity=300, cache_duration=48)
|
||||
|
||||
# 向后兼容的别名
|
||||
PredictionCache = CacheManager
|
||||
# 为保持向后兼容,我们保留原来的prediction_cache
|
||||
prediction_cache_old = prediction_cache
|
||||
@@ -0,0 +1,558 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import getpass
|
||||
import secrets
|
||||
import logging
|
||||
import platform
|
||||
import socket
|
||||
import hashlib
|
||||
import base64
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
import pymysql
|
||||
from dotenv import load_dotenv, set_key, find_dotenv
|
||||
|
||||
# Logging: INFO and above, written to stdout so messages interleave with the
# wizard's own prompts.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('init_wizard')
|
||||
|
||||
class InitWizard:
    """Interactive command-line wizard that collects settings and writes ``.env``.

    The wizard seeds its working configuration from environment variables
    (after ``load_dotenv()``), prompts for database, Flask, API-key, security,
    crawler and system settings, writes the result to ``.env`` in the current
    working directory, and finally creates security artefacts (certificate,
    sensitive-data filter, IP blacklist) under ``<data_dir>/security``.
    """

    # Answers accepted by true/false prompts (compared case-insensitively).
    _TRUE_WORDS = frozenset({'true', 't', 'yes', 'y', '1', 'on'})
    _FALSE_WORDS = frozenset({'false', 'f', 'no', 'n', '0', 'off'})
    # Shown instead of a stored secret so neither its text nor its length leaks.
    _SECRET_MASK = '********'

    def __init__(self):
        """Load ``.env`` into the environment and build the working configuration."""
        load_dotenv()

        self.config = self._load_config_from_env()

        # Toggleable security features; each key ``k`` maps to the
        # ``enable_k`` entry of ``self.config['security']``.
        self.security_options = {
            'rate_limit': {
                'name': '请求速率限制',
                'description': '防止API被滥用,限制单个IP的请求频率',
                'default': True
            },
            'ip_blocking': {
                'name': 'IP黑名单',
                'description': '阻止可疑IP访问系统',
                'default': True
            },
            'sensitive_data_filter': {
                'name': '敏感信息过滤',
                'description': '自动识别并屏蔽输出内容中的敏感信息(如手机号、邮箱等)',
                'default': True
            },
            'mutual_auth': {
                'name': '双向认证',
                'description': '要求API调用方提供有效证书,增强API安全性(需要HTTPS)',
                'default': False
            }
        }

    # ------------------------------------------------------------------
    # Environment helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _env_bool(name, default):
        """Return True only if env var *name* (or *default*) spells ``true``."""
        fallback = 'true' if default else 'false'
        return os.getenv(name, fallback).lower() == 'true'

    @staticmethod
    def _env_number(name, default, cast=int):
        """Read env var *name* through *cast*; fall back to *default* if invalid.

        A malformed value must not crash the wizard, because the wizard is the
        tool used to repair the configuration.
        """
        raw = os.getenv(name)
        if raw is None:
            return default
        try:
            return cast(raw)
        except ValueError:
            print(f"⚠️ 环境变量 {name} 的值无效 ({raw!r}),使用默认值 {default}")
            return default

    def _load_config_from_env(self):
        """Build the nested configuration dict from environment variables."""
        env_bool = self._env_bool
        env_number = self._env_number
        return {
            # Database connection
            'db': {
                'host': os.getenv('DB_HOST', 'localhost'),
                'port': env_number('DB_PORT', 3306),
                'user': os.getenv('DB_USER', 'root'),
                'password': os.getenv('DB_PASSWORD', ''),
                'database': os.getenv('DB_NAME', 'Weibo_PublicOpinion_AnalysisSystem'),
                'ssl': env_bool('DB_SSL', False)
            },
            # Flask application
            'app': {
                'host': os.getenv('FLASK_HOST', '127.0.0.1'),
                'port': env_number('FLASK_PORT', 5000),
                'secret_key': os.getenv('FLASK_SECRET_KEY', ''),
                'enable_https': env_bool('ENABLE_HTTPS', False),
                'debug': env_bool('FLASK_DEBUG', False)
            },
            # LLM provider API keys
            'api_keys': {
                'openai': os.getenv('OPENAI_API_KEY', ''),
                'anthropic': os.getenv('ANTHROPIC_API_KEY', ''),
                'deepseek': os.getenv('DEEPSEEK_API_KEY', '')
            },
            # Security switches and policy
            'security': {
                'enable_rate_limit': env_bool('ENABLE_RATE_LIMIT', True),
                'enable_ip_blocking': env_bool('ENABLE_IP_BLOCKING', True),
                'enable_sensitive_data_filter': env_bool('ENABLE_SENSITIVE_DATA_FILTER', True),
                'enable_mutual_auth': env_bool('ENABLE_MUTUAL_AUTH', False),
                'min_password_length': env_number('MIN_PASSWORD_LENGTH', 8),
                'session_timeout': env_number('SESSION_TIMEOUT', 120),  # minutes
            },
            # Crawler
            'crawler': {
                'interval': env_number('CRAWL_INTERVAL', 18000),  # seconds
                'max_retries': env_number('CRAWL_MAX_RETRIES', 3),
                'timeout': env_number('CRAWL_TIMEOUT', 30),
                'max_concurrent': env_number('CRAWL_MAX_CONCURRENT', 2),
                'user_agent': os.getenv('CRAWL_USER_AGENT', 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36')
            },
            # System
            'system': {
                'initialized': env_bool('SYSTEM_INITIALIZED', False),
                'version': os.getenv('SYSTEM_VERSION', '2.0.0'),
                'log_level': os.getenv('LOG_LEVEL', 'INFO'),
                'data_dir': os.getenv('DATA_DIR', 'data'),
                'temp_dir': os.getenv('TEMP_DIR', 'temp'),
                'cache_dir': os.getenv('CACHE_DIR', 'cache'),
                'max_model_memory': env_number('MAX_MODEL_MEMORY_USAGE', 4.0, float),  # GB
            }
        }

    # ------------------------------------------------------------------
    # Main flow
    # ------------------------------------------------------------------
    def start(self):
        """Run the wizard: prompt for every section, save, apply security steps.

        Ctrl+C at any prompt cancels the wizard; any other exception is logged
        and reported without a traceback.
        """
        self._print_welcome()

        config_saved = False
        try:
            if self.config['system']['initialized']:
                choice = input("\n系统已经初始化过。您想重新配置吗? [y/N]: ").strip().lower()
                if choice != 'y':
                    print("初始化向导已退出。如需重新配置,请设置环境变量 SYSTEM_INITIALIZED=false 或删除 .env 文件。")
                    return

            for configure in (
                self._configure_database,
                self._configure_app,
                self._configure_api_keys,
                self._configure_security,
                self._configure_crawler,
                self._configure_system,
            ):
                configure()

            self._save_config()
            config_saved = True

            self._apply_security_measures()

            print("\n✅ 初始化完成!系统已成功配置。")
            print("您现在可以运行 python app.py 启动应用。")

        except KeyboardInterrupt:
            # Be accurate about whether .env was already written.
            if config_saved:
                print("\n\n初始化向导已取消。配置已保存,但安全措施尚未全部应用。")
            else:
                print("\n\n初始化向导已取消。配置未保存。")
        except Exception as e:
            logger.error(f"初始化过程中发生错误: {e}")
            print(f"\n❌ 初始化失败: {e}")
            print("请检查错误并重试。")

    def _print_welcome(self):
        """Print the banner and basic host information."""
        banner = [
            "\n" + "="*80,
            " "*20 + "微博舆情分析预测系统 - 初始化向导 v2.0",
            "="*80,
            "\n欢迎使用微博舆情分析预测系统!此向导将引导您完成系统的初始配置。",
            "按Ctrl+C可随时退出向导。",
            "\n系统信息:",
            f" • 操作系统: {platform.system()} {platform.release()}",
            f" • Python版本: {platform.python_version()}",
            f" • 主机名: {socket.gethostname()}",
            f" • 当前时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            "\n让我们开始配置吧!每个选项都有默认值,直接按回车即可使用默认值。",
            "-"*80,
        ]
        print("\n".join(banner))

    # ------------------------------------------------------------------
    # Configuration sections
    # ------------------------------------------------------------------
    def _configure_database(self):
        """Prompt for database settings and test them, repeating on request.

        A loop (rather than recursion) re-runs the section after a failed
        connection test, so repeated retries cannot exhaust the stack.
        """
        db = self.config['db']
        while True:
            print("\n📦 数据库配置")
            print("-"*50)

            db['host'] = self._prompt("数据库主机", db['host'])
            db['port'] = self._prompt_int(
                "数据库端口", db['port'],
                minimum=1, maximum=65535, invalid_message="端口号无效")
            db['user'] = self._prompt("数据库用户名", db['user'])
            # Hidden input; the hint never reveals the stored password length.
            db['password'] = self._prompt_secret("数据库密码", db['password'])
            db['database'] = self._prompt("数据库名", db['database'])
            db['ssl'] = self._prompt_bool("使用SSL连接 (true/false)", db['ssl'])

            print("\n正在测试数据库连接...")
            try:
                self._test_db_connection()
            except Exception as e:
                print(f"❌ 数据库连接失败: {e}")
                retry = input("是否重新配置数据库连接? [Y/n]: ").strip().lower()
                if retry != 'n':
                    continue
                print("跳过数据库连接测试,但配置可能不正确。")
            else:
                print("✅ 数据库连接成功!")
            return

    def _configure_app(self):
        """Prompt for Flask host/port, secret key, HTTPS and debug settings."""
        print("\n🚀 应用配置")
        print("-"*50)
        app = self.config['app']

        app['host'] = self._prompt(
            "监听地址 (0.0.0.0表示所有网络接口)", app['host'])
        app['port'] = self._prompt_int(
            "监听端口", app['port'],
            minimum=1, maximum=65535, invalid_message="端口号无效")

        # Generate a secret key automatically; only a short prefix is shown.
        if not app['secret_key']:
            app['secret_key'] = secrets.token_hex(32)
            print(f"已自动生成应用密钥: {app['secret_key'][:8]}...")
        else:
            regenerate = input("应用密钥已存在。是否重新生成? [y/N]: ").strip().lower()
            if regenerate == 'y':
                app['secret_key'] = secrets.token_hex(32)
                print(f"已重新生成应用密钥: {app['secret_key'][:8]}...")

        app['enable_https'] = self._prompt_bool(
            "启用HTTPS (true/false)", app['enable_https'])
        app['debug'] = self._prompt_bool(
            "启用调试模式 (true/false, 生产环境建议false)", app['debug'])

    def _configure_api_keys(self):
        """Prompt for the LLM provider API keys.

        Existing keys are never echoed; input is read without echo. If no key
        ends up configured the user may choose to run the section again.
        """
        providers = (
            ('openai', "是否配置OpenAI API密钥? (y/n)", "OpenAI API密钥"),
            ('anthropic', "是否配置Anthropic (Claude) API密钥? (y/n)", "Anthropic API密钥"),
            ('deepseek', "是否配置DeepSeek API密钥? (y/n)", "DeepSeek API密钥"),
        )
        keys = self.config['api_keys']
        while True:
            print("\n🔑 API密钥配置")
            print("-"*50)
            print("系统支持多个大语言模型,至少需要配置一个API密钥。")

            for provider, question, label in providers:
                wants_key = self._prompt(question, "y" if keys[provider] else "n")
                if wants_key.lower() == 'y':
                    keys[provider] = self._prompt_secret(label, keys[provider], strip=True)

            if any(keys[provider] for provider, _, _ in providers):
                return

            print("⚠️ 警告: 您未配置任何API密钥,系统的AI分析功能将不可用。")
            confirm = input("是否继续? [Y/n]: ").strip().lower()
            if confirm != 'n':
                return

    def _configure_security(self):
        """Prompt for the security switches, password policy and session timeout."""
        print("\n🔒 安全配置")
        print("-"*50)
        security = self.config['security']

        for key, option in self.security_options.items():
            setting = f'enable_{key}'
            print(f"\n{option['name']}: {option['description']}")
            security[setting] = self._prompt_bool(
                f"启用{option['name']} (true/false)", security[setting])

        # Password policy: zero or negative lengths are rejected.
        security['min_password_length'] = self._prompt_int(
            "最小密码长度 (推荐不低于8)", security['min_password_length'], minimum=1)
        if security['min_password_length'] < 6:
            print("⚠️ 警告: 短密码容易被暴力破解,建议设置更长的密码。")

        # Session timeout in minutes; must be positive.
        security['session_timeout'] = self._prompt_int(
            "会话超时时间 (分钟)", security['session_timeout'], minimum=1)

    def _configure_crawler(self):
        """Prompt for crawler interval, retries, timeout, concurrency and User-Agent."""
        print("\n🕷️ 爬虫配置")
        print("-"*50)
        crawler = self.config['crawler']

        # (config key, prompt label, smallest accepted value)
        numeric_settings = (
            ('interval', "爬取间隔 (秒)", 1),
            ('max_retries', "最大重试次数", 0),
            ('timeout', "超时时间 (秒)", 1),
            ('max_concurrent', "最大并发数", 1),
        )
        for key, label, minimum in numeric_settings:
            crawler[key] = self._prompt_int(label, crawler[key], minimum=minimum)

        crawler['user_agent'] = self._prompt("User-Agent", crawler['user_agent'])

    def _configure_system(self):
        """Prompt for log level, working directories and the model memory limit."""
        print("\n⚙️ 系统配置")
        print("-"*50)
        system = self.config['system']

        # Log level
        log_levels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
        current_level = system['log_level']
        print(f"可选日志级别: {', '.join(log_levels)}")
        log_level = self._prompt("日志级别", current_level).upper()
        if log_level in log_levels:
            system['log_level'] = log_level
        else:
            print(f"无效的日志级别,使用默认值 {current_level}")

        # Working directories: (config key, prompt label, confirmation prefix)
        directories = (
            ('data_dir', "数据目录", "已创建数据目录"),
            ('cache_dir', "缓存目录", "已创建缓存目录"),
            ('temp_dir', "临时文件目录", "已创建临时文件目录"),
        )
        for key, label, created_message in directories:
            path = self._prompt(label, system[key])
            if path:
                system[key] = path
                os.makedirs(path, exist_ok=True)
                print(f"{created_message}: {path}")

        # Model memory limit in GB; must be a positive, finite number.
        memory_str = self._prompt(
            "最大模型内存使用量 (GB)", str(system['max_model_memory']))
        try:
            memory = float(memory_str)
            # The chained comparison also rejects NaN and infinity.
            if not 0 < memory < float('inf'):
                raise ValueError(memory_str)
            system['max_model_memory'] = memory
        except ValueError:
            print(f"无效输入,使用默认值 {system['max_model_memory']}")

        # Mark the system as initialised.
        system['initialized'] = True

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------
    @staticmethod
    def _format_env_value(value):
        """Render *value* for a ``KEY=value`` line in ``.env``.

        Booleans become ``true``/``false``. A value is written unquoted unless
        the dotenv reader would alter it in that form: leading or trailing
        whitespace, whitespace followed by ``#`` (read as a comment), a leading
        quote character, or a line break. Such values are single-quoted with
        ``\\`` and ``'`` backslash-escaped.
        """
        if isinstance(value, bool):
            return 'true' if value else 'false'
        text = str(value)
        needs_quoting = (
            text != text.strip()
            or text[:1] in ('"', "'")
            or re.search(r'\s#', text) is not None
            or '\n' in text
            or '\r' in text
        )
        if not needs_quoting:
            return text
        escaped = text.replace('\\', '\\\\').replace("'", "\\'")
        return f"'{escaped}'"

    @staticmethod
    def _restrict_permissions(path):
        """Set *path* to mode 600 on non-Windows systems; return True on success."""
        if platform.system() == "Windows":
            return False
        try:
            os.chmod(path, 0o600)
        except OSError as e:
            logger.warning(f"设置文件权限失败: {e}")
            return False
        return True

    def _save_config(self):
        """Write the configuration to ``.env`` (UTF-8) and create a timestamped copy.

        The file holds secrets, so it is created owner-only where the platform
        supports it, and the backup copy is restricted the same way.
        """
        print("\n正在保存配置...")

        cfg = self.config
        sections = [
            ("数据库配置", [
                ('DB_HOST', cfg['db']['host']),
                ('DB_PORT', cfg['db']['port']),
                ('DB_USER', cfg['db']['user']),
                ('DB_PASSWORD', cfg['db']['password']),
                ('DB_NAME', cfg['db']['database']),
                ('DB_SSL', cfg['db']['ssl']),
            ]),
            ("应用配置", [
                ('FLASK_HOST', cfg['app']['host']),
                ('FLASK_PORT', cfg['app']['port']),
                ('FLASK_SECRET_KEY', cfg['app']['secret_key']),
                ('ENABLE_HTTPS', cfg['app']['enable_https']),
                ('FLASK_DEBUG', cfg['app']['debug']),
            ]),
            ("API密钥", [
                ('OPENAI_API_KEY', cfg['api_keys']['openai']),
                ('ANTHROPIC_API_KEY', cfg['api_keys']['anthropic']),
                ('DEEPSEEK_API_KEY', cfg['api_keys']['deepseek']),
            ]),
            ("安全配置", [
                ('ENABLE_RATE_LIMIT', cfg['security']['enable_rate_limit']),
                ('ENABLE_IP_BLOCKING', cfg['security']['enable_ip_blocking']),
                ('ENABLE_SENSITIVE_DATA_FILTER', cfg['security']['enable_sensitive_data_filter']),
                ('ENABLE_MUTUAL_AUTH', cfg['security']['enable_mutual_auth']),
                ('MIN_PASSWORD_LENGTH', cfg['security']['min_password_length']),
                ('SESSION_TIMEOUT', cfg['security']['session_timeout']),
            ]),
            ("爬虫配置", [
                ('CRAWL_INTERVAL', cfg['crawler']['interval']),
                ('CRAWL_MAX_RETRIES', cfg['crawler']['max_retries']),
                ('CRAWL_TIMEOUT', cfg['crawler']['timeout']),
                ('CRAWL_MAX_CONCURRENT', cfg['crawler']['max_concurrent']),
                ('CRAWL_USER_AGENT', cfg['crawler']['user_agent']),
            ]),
            ("系统配置", [
                ('SYSTEM_INITIALIZED', cfg['system']['initialized']),
                ('SYSTEM_VERSION', cfg['system']['version']),
                ('LOG_LEVEL', cfg['system']['log_level']),
                ('DATA_DIR', cfg['system']['data_dir']),
                ('TEMP_DIR', cfg['system']['temp_dir']),
                ('CACHE_DIR', cfg['system']['cache_dir']),
                ('MAX_MODEL_MEMORY_USAGE', cfg['system']['max_model_memory']),
            ]),
        ]

        lines = [
            "# 微博舆情分析预测系统配置文件",
            f"# 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        ]
        for title, entries in sections:
            lines.append("")
            lines.append(f"# {title}")
            lines.extend(
                f"{key}={self._format_env_value(value)}" for key, value in entries)

        def private_opener(path, flags):
            # A newly created file is owner-only from the first byte written.
            return os.open(path, flags, 0o600)

        # UTF-8 is what python-dotenv reads by default; the locale encoding
        # would break the Chinese header on non-UTF-8 systems.
        with open('.env', 'w', encoding='utf-8', opener=private_opener) as f:
            f.write('\n'.join(lines) + '\n')
        # Covers a pre-existing .env, whose mode the opener does not change.
        self._restrict_permissions('.env')

        print("✅ 配置已保存到 .env 文件")

        # Timestamped copy of the configuration that was just written.
        backup_path = f".env.backup.{datetime.now().strftime('%Y%m%d%H%M%S')}"
        shutil.copy2('.env', backup_path)
        self._restrict_permissions(backup_path)
        print(f"✅ 配置备份已保存到 {backup_path}")

    def _test_db_connection(self):
        """Open and immediately close a MySQL connection with the current settings.

        Raises whatever ``pymysql.connect`` raises when the connection fails.
        """
        db = self.config['db']
        # NOTE(review): the nested {'ssl': {...}} shape is kept exactly as it
        # was; confirm it matches what the installed PyMySQL version expects
        # for its ``ssl`` argument.
        connection = pymysql.connect(
            host=db['host'],
            port=db['port'],
            user=db['user'],
            password=db['password'],
            database=db['database'],
            charset='utf8mb4',
            ssl={'ssl': {'ca': None}} if db['ssl'] else None
        )
        connection.close()

    # ------------------------------------------------------------------
    # Security artefacts
    # ------------------------------------------------------------------
    def _apply_security_measures(self):
        """Create the security directory and the artefacts enabled in the config."""
        print("\n正在应用安全措施...")

        security_dir = os.path.join(self.config['system']['data_dir'], 'security')
        os.makedirs(security_dir, exist_ok=True)

        # File permissions are only adjusted on Unix-like systems.
        if self._restrict_permissions('.env'):
            print("✅ 已设置.env文件权限为600 (只有所有者可读写)")

        security = self.config['security']
        if security['enable_mutual_auth']:
            self._generate_certificates(security_dir)
        if security['enable_sensitive_data_filter']:
            self._write_sensitive_filter(security_dir)
        if security['enable_ip_blocking']:
            self._ensure_ip_blacklist(security_dir)

    def _generate_certificates(self, security_dir):
        """Create a self-signed certificate under ``<security_dir>/certs`` if missing."""
        cert_dir = os.path.join(security_dir, 'certs')
        os.makedirs(cert_dir, exist_ok=True)
        key_file = os.path.join(cert_dir, 'server.key')
        cert_file = os.path.join(cert_dir, 'server.crt')

        try:
            # Probe for OpenSSL before attempting generation.
            subprocess.run(['openssl', 'version'], check=True, capture_output=True)

            if not os.path.exists(key_file) or not os.path.exists(cert_file):
                print("正在生成SSL证书...")
                subprocess.run([
                    'openssl', 'req', '-x509', '-newkey', 'rsa:4096',
                    '-keyout', key_file, '-out', cert_file,
                    '-days', '365', '-nodes',
                    '-subj', '/CN=localhost'
                ], check=True)
                # The key is unencrypted (-nodes); keep it owner-only.
                self._restrict_permissions(key_file)
                print(f"✅ SSL证书已生成: {cert_file}")
        except (subprocess.CalledProcessError, FileNotFoundError):
            # FileNotFoundError is what a missing openssl binary raises.
            print("⚠️ OpenSSL不可用,无法生成SSL证书。如需使用HTTPS,请手动配置证书。")
        except Exception as e:
            logger.warning(f"生成SSL证书失败: {e}")

    def _write_sensitive_filter(self, security_dir):
        """Write the sensitive-data filter patterns to ``sensitive_filter.json``."""
        filter_config = {
            'enabled': True,
            'patterns': {
                'phone': r'\b1[3-9]\d{9}\b',
                # Fixed: the TLD class used to be [A-Z|a-z], which also
                # accepted a literal '|'.
                'email': r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
                'id_card': r'\b[1-9]\d{5}(19|20)\d{2}(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])\d{3}[\dXx]\b',
                'credit_card': r'\b\d{4}[ -]?\d{4}[ -]?\d{4}[ -]?\d{4}\b',
                'address': r'(北京|上海|广州|深圳|天津|重庆|南京|杭州|武汉|成都|西安)市.*?(路|街|道|巷).*?(号)'
            },
            'replacements': {
                'phone': '***********',
                'email': '******@*****',
                'id_card': '******************',
                'credit_card': '****************',
                'address': '[地址已隐藏]'
            }
        }

        filter_path = os.path.join(security_dir, 'sensitive_filter.json')
        with open(filter_path, 'w', encoding='utf-8') as f:
            json.dump(filter_config, f, ensure_ascii=False, indent=2)

        print(f"✅ 敏感信息过滤器配置已保存到 {filter_path}")

    def _ensure_ip_blacklist(self, security_dir):
        """Create an empty IP blacklist file unless one already exists."""
        blacklist_path = os.path.join(security_dir, 'ip_blacklist.txt')
        if not os.path.exists(blacklist_path):
            # NOTE(review): written with the platform default encoding, as
            # before; confirm which encoding the blacklist reader expects.
            with open(blacklist_path, 'w') as f:
                f.write("# 每行一个IP地址\n")
            print(f"✅ IP黑名单文件已创建: {blacklist_path}")

    # ------------------------------------------------------------------
    # Prompt helpers
    # ------------------------------------------------------------------
    def _prompt(self, prompt, default=""):
        """Ask for a line of input; return *default* when the answer is empty."""
        if default:
            user_input = input(f"{prompt} [{default}]: ").strip()
        else:
            user_input = input(f"{prompt}: ").strip()

        return user_input if user_input else default

    def _prompt_int(self, prompt, current, *, minimum=None, maximum=None,
                    invalid_message="无效输入"):
        """Ask for an integer within [minimum, maximum]; keep *current* if invalid."""
        raw = self._prompt(prompt, str(current))
        try:
            value = int(raw)
        except ValueError:
            value = None
        out_of_range = value is not None and (
            (minimum is not None and value < minimum)
            or (maximum is not None and value > maximum)
        )
        if value is None or out_of_range:
            print(f"{invalid_message},使用默认值 {current}")
            return current
        return value

    def _prompt_bool(self, prompt, current):
        """Ask a true/false question; keep *current* on an unrecognised answer."""
        shown = str(current).lower()
        answer = self._prompt(prompt, shown).lower()
        if answer in self._TRUE_WORDS:
            return True
        if answer in self._FALSE_WORDS:
            return False
        print(f"无效输入,使用默认值 {shown}")
        return current

    def _prompt_secret(self, prompt, current, strip=False):
        """Ask for a secret without echo; an empty answer keeps *current*.

        The hint is a fixed-width mask, so neither the stored value nor its
        length is revealed on screen.
        """
        hint = f" [{self._SECRET_MASK}]" if current else ""
        entered = getpass.getpass(f"{prompt}{hint}: ")
        if strip:
            entered = entered.strip()
        return entered if entered else current
|
||||
|
||||
|
||||
# Script entry point: run the interactive wizard.
if __name__ == "__main__":
    InitWizard().start()
|
||||
@@ -0,0 +1,285 @@
|
||||
import os
|
||||
import sys
|
||||
import pickle
|
||||
import marshal
|
||||
import types
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
# Module-wide logger for the model loading helpers; records INFO and above.
logger = logging.getLogger('model_loader')
logger.setLevel(logging.INFO)
|
||||
|
||||
def load_sentiment_model(model_path, device=None):
    """Load a sentiment model stored as a marshalled code object.

    Args:
        model_path: Path (``str`` or ``os.PathLike``) whose name ends in
            ``.marshal`` or ``.marshal.3``.
        device: Ignored; accepted so the signature matches the other loaders.

    Returns:
        A function named ``sentiment_func`` built from the unmarshalled code
        object, using this module's globals.

    Raises:
        ValueError: If the file extension is not supported.
        TypeError: If the file does not contain a code object.
        Exception: Any I/O or unmarshalling error; every failure is logged
            and re-raised.
    """
    try:
        logger.info(f"加载情感分析模型: {model_path}")

        path = os.fspath(model_path)
        if not path.endswith(('.marshal', '.marshal.3')):
            raise ValueError(f"不支持的情感模型格式: {model_path}")

        # SECURITY: the unmarshalled code object is turned into a callable, so
        # a tampered file means arbitrary code execution. Load trusted model
        # files only.
        with open(path, 'rb') as f:
            model_data = marshal.load(f)

        # marshal can hold any basic value (dict, list, ...); only a code
        # object can become a function, so fail with a clear message.
        if not isinstance(model_data, types.CodeType):
            raise TypeError(
                f"情感模型文件不包含代码对象: {model_path} "
                f"(实际类型: {type(model_data).__name__})")

        sentiment_func = types.FunctionType(model_data, globals(), "sentiment_func")
        logger.info("情感分析模型加载成功")
        return sentiment_func
    except Exception as e:
        logger.error(f"加载情感分析模型失败: {e}")
        raise
|
||||
|
||||
def load_bert_ctm_model(model_dir, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Load the BERT-CTM model and its tokenizer.

    Args:
        model_dir: Directory containing ``final_model.pt``, or the path of a
            ``.pt`` checkpoint itself.
        device: Device used for ``torch.load(map_location=...)`` and
            ``model.to``. The default is chosen once, at import time.

    Returns:
        dict with ``'model'`` (switched to eval mode), ``'tokenizer'`` and
        ``'device'``.

    Raises:
        Exception: Any import or loading error is logged and re-raised.
    """
    try:
        logger.info(f"加载BERT-CTM模型: {model_dir}")

        # Make the project's model package importable. Registering the path
        # only once keeps sys.path from growing on every call.
        if 'model_pro' not in sys.path:
            sys.path.append('model_pro')
        from BERT_CTM import BERT_CTM
        from transformers import BertTokenizer

        # Load the weights.
        if model_dir.endswith('.pt'):
            model_path = model_dir
        else:
            model_path = os.path.join(model_dir, 'final_model.pt')
        model = BERT_CTM()
        # SECURITY: torch.load unpickles the file; load trusted checkpoints only.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()

        # Load the tokenizer.
        # NOTE(review): resolved next to the *parent* of model_dir; when
        # model_dir is a .pt path this is the checkpoint's own directory
        # instead - confirm both layouts are intended.
        tokenizer_path = os.path.join(os.path.dirname(model_dir), 'bert_model')
        tokenizer = BertTokenizer.from_pretrained(tokenizer_path)

        logger.info("BERT-CTM模型加载成功")
        return {
            'model': model,
            'tokenizer': tokenizer,
            'device': device
        }
    except Exception as e:
        logger.error(f"加载BERT-CTM模型失败: {e}")
        raise
|
||||
|
||||
def load_bcat_model(model_dir, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Load the BCAT model, its configuration and its tokenizer.

    Args:
        model_dir: Directory containing ``config.json``, ``model.pt`` and a
            ``tokenizer`` sub-directory.
        device: Device used for ``torch.load(map_location=...)`` and
            ``model.to``. The default is chosen once, at import time.

    Returns:
        dict with ``'model'`` (switched to eval mode), ``'tokenizer'``,
        ``'device'`` and ``'config'`` (the parsed ``config.json``).

    Raises:
        Exception: Any import or loading error is logged and re-raised.
    """
    try:
        logger.info(f"加载BCAT模型: {model_dir}")

        # Make the project's model package importable. Registering the path
        # only once keeps sys.path from growing on every call.
        if 'model_pro' not in sys.path:
            sys.path.append('model_pro')
        from BCAT import BCAT
        from transformers import BertTokenizer

        # Model configuration; its keys are passed to BCAT as keyword arguments.
        config_path = os.path.join(model_dir, 'config.json')
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        model = BCAT(**config)

        # Model weights.
        model_path = os.path.join(model_dir, 'model.pt')
        # SECURITY: torch.load unpickles the file; load trusted checkpoints only.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()

        # Tokenizer.
        tokenizer_path = os.path.join(model_dir, 'tokenizer')
        tokenizer = BertTokenizer.from_pretrained(tokenizer_path)

        logger.info("BCAT模型加载成功")
        return {
            'model': model,
            'tokenizer': tokenizer,
            'device': device,
            'config': config
        }
    except Exception as e:
        logger.error(f"加载BCAT模型失败: {e}")
        raise
|
||||
|
||||
def _read_topic_labels(model_dir):
    """Return the mapping stored in ``<model_dir>/labels.json``, or None if absent."""
    labels_path = os.path.join(model_dir, 'labels.json')
    if not os.path.exists(labels_path):
        return None
    with open(labels_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_topic_classifier(model_dir, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Load the topic classifier, trying transformers first and plain PyTorch second.

    Args:
        model_dir: Directory holding either a transformers checkpoint or a
            ``model.pt`` file (optionally with ``tokenizer.pkl``), plus an
            optional ``labels.json``.
        device: Target device. The default is chosen once, at import time.

    Returns:
        dict with ``'model'``, ``'tokenizer'`` (``None`` if the fallback finds
        no ``tokenizer.pkl``), ``'labels_map'`` and ``'device'``.
        NOTE(review): ``labels_map`` read from JSON has string keys, while
        ``model.config.id2label`` is used as-is - callers should confirm which
        key type they expect.

    Raises:
        ValueError: If neither loading strategy finds a usable model.
        Exception: Other failures of the fallback path; every failure is
            logged and re-raised.
    """
    try:
        logger.info(f"加载话题分类模型: {model_dir}")

        # Strategy 1: transformers checkpoint.
        try:
            from transformers import AutoModelForSequenceClassification, AutoTokenizer

            model = AutoModelForSequenceClassification.from_pretrained(model_dir)
            model.to(device)
            model.eval()

            tokenizer = AutoTokenizer.from_pretrained(model_dir)

            # Prefer labels.json, then the labels embedded in the model config.
            labels_map = _read_topic_labels(model_dir)
            if labels_map is None:
                if hasattr(model.config, 'id2label'):
                    labels_map = model.config.id2label
                else:
                    labels_map = {}

            logger.info("话题分类模型加载成功 (transformers)")
            return {
                'model': model,
                'tokenizer': tokenizer,
                'labels_map': labels_map,
                'device': device
            }
        except Exception as e:
            logger.warning(f"使用transformers加载失败,尝试其他方法: {e}")

        # Strategy 2: pickled PyTorch object.
        model_path = os.path.join(model_dir, 'model.pt')
        if os.path.exists(model_path):
            # SECURITY: torch.load unpickles the file; load trusted files only.
            model = torch.load(model_path, map_location=device)
            # Match the transformers branch: a loaded module is returned in
            # inference mode. Objects without eval() (e.g. a state dict) are
            # returned untouched.
            if hasattr(model, 'eval'):
                model.eval()

            tokenizer_path = os.path.join(model_dir, 'tokenizer.pkl')
            if os.path.exists(tokenizer_path):
                # SECURITY: pickle.load can execute arbitrary code.
                with open(tokenizer_path, 'rb') as f:
                    tokenizer = pickle.load(f)
            else:
                tokenizer = None

            labels_map = _read_topic_labels(model_dir)
            if labels_map is None:
                labels_map = {}

            logger.info("话题分类模型加载成功 (PyTorch)")
            return {
                'model': model,
                'tokenizer': tokenizer,
                'labels_map': labels_map,
                'device': device
            }

        raise ValueError(f"无法加载模型: {model_dir}")
    except Exception as e:
        logger.error(f"加载话题分类模型失败: {e}")
        raise
|
||||
|
||||
def load_echarts_optimizer():
    """Create the ECharts optimizer used to speed up rendering of large data sets.

    Returns:
        An ``EChartsOptimizer`` instance, or ``None`` if construction fails
        (the failure is logged).
    """
    try:
        import copy

        class EChartsOptimizer:
            """Adjusts ECharts option dicts so large series render faster."""

            # A series longer than this gets the large-data flags.
            LARGE_SERIES_MIN_POINTS = 5000
            # Upper bound on points kept per series after down-sampling.
            MAX_POINTS = 50000

            def __init__(self):
                self.chunk_size = 1000  # items per chunk in chunk_process_data
                logger.info("ECharts优化器初始化成功")

            def optimize_option(self, option):
                """Return an optimized deep copy of *option*; falsy input is returned as is."""
                if not option:
                    return option

                # Work on a deep copy so the caller's object is never modified.
                option = copy.deepcopy(option)

                # Progressive rendering: points per frame and activation threshold.
                if 'progressive' not in option:
                    option['progressive'] = 300
                if 'progressiveThreshold' not in option:
                    option['progressiveThreshold'] = 5000

                if 'series' in option and isinstance(option['series'], list):
                    for series in option['series']:
                        # Skip malformed entries instead of raising on them.
                        if not isinstance(series, dict):
                            continue
                        data = series.get('data')
                        if (
                            isinstance(data, list)
                            and len(data) > self.LARGE_SERIES_MIN_POINTS
                            and series.get('type') in ['scatter', 'line']
                        ):
                            self._optimize_large_data_series(series)

                return option

            def _optimize_large_data_series(self, series):
                """Flag *series* as large and down-sample it past MAX_POINTS (in place)."""
                series['large'] = True
                series['largeThreshold'] = 2000

                total = len(series['data'])
                if total > self.MAX_POINTS:
                    # Ceiling division guarantees at most MAX_POINTS points;
                    # floor division left series of 50,001-99,999 points
                    # untouched and could keep almost twice the limit.
                    step = -(-total // self.MAX_POINTS)
                    series['data'] = series['data'][::step]
                    series['sampling'] = 'average'

                return series

            def chunk_process_data(self, data, process_func):
                """Apply *process_func* to consecutive chunks and concatenate the results."""
                result = []
                for i in range(0, len(data), self.chunk_size):
                    chunk = data[i:i + self.chunk_size]
                    result.extend(process_func(chunk))
                return result

        return EChartsOptimizer()
    except Exception as e:
        logger.error(f"加载ECharts优化器失败: {e}")
        return None
|
||||
|
||||
# Public API: every loader defined in this module.
__all__ = [
    'load_sentiment_model',
    'load_bert_ctm_model',
    'load_bcat_model',
    'load_topic_classifier',
    'load_echarts_optimizer',
]
|
||||
@@ -0,0 +1,364 @@
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
import logging
|
||||
import gc
|
||||
import torch
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Module-wide logger for the model manager; records INFO and above.
logger = logging.getLogger('model_manager')
logger.setLevel(logging.INFO)
|
||||
|
||||
class ModelManager:
|
||||
"""
|
||||
模型管理器 - 实现模型预加载和按需卸载技术
|
||||
|
||||
功能:
|
||||
1. 预加载经常使用的模型,减少加载等待时间
|
||||
2. 使用LRU (Least Recently Used) 策略管理内存中加载的模型
|
||||
3. 支持模型的异步加载和监控
|
||||
4. 自动检测并释放长时间未使用的模型内存
|
||||
5. 提供模型使用统计
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls):
    """Return the single shared instance, creating it on first use.

    Creation happens while holding the class-level lock.
    """
    with cls._lock:
        instance = cls._instance
        if instance is None:
            instance = super(ModelManager, cls).__new__(cls)
            cls._instance = instance
        return instance
|
||||
|
||||
def __init__(self):
    """Set up the shared manager state once and start the monitor thread.

    ``__new__`` hands out the same instance on every ``ModelManager()`` call,
    so ``__init__`` runs each time; the ``initialized`` attribute marks a
    completed set-up. The check and the set-up run under the class lock so two
    threads constructing the manager at the same time cannot both initialise
    it (which would reset the state and start a second monitor thread).
    """
    with self._lock:
        if hasattr(self, 'initialized'):
            return

        # Loaded models, ordered from least to most recently used (LRU).
        self.loaded_models = OrderedDict()
        # Per-model usage statistics.
        self.model_stats = {}
        # Registration data for each model (path, preload flag, size, ...).
        self.preload_config = {}
        # Memory budget for loaded models, in GB.
        self.max_memory_usage = float(os.getenv('MAX_MODEL_MEMORY_USAGE', '4.0'))
        # One lock per model id, held while that model is being loaded.
        self.loading_locks = {}
        # Idle time after which a model may be unloaded, in minutes.
        self.unload_timeout = int(os.getenv('MODEL_UNLOAD_TIMEOUT', '30'))

        # Background monitor; daemon so it never blocks interpreter exit.
        self.monitor_thread = threading.Thread(target=self._monitor_models, daemon=True)
        self.monitor_thread.start()

        self.initialized = True

    logger.info(f"模型管理器初始化完成,最大内存占用: {self.max_memory_usage}GB")
|
||||
|
||||
def register_model(self, model_id, model_path, preload=False, model_size_gb=0.5,
                   load_function=None, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Register a model and optionally start loading it in the background.

    Args:
        model_id: Unique identifier of the model.
        model_path: Path of the model.
        preload: If true, start a daemon thread that preloads the model.
        model_size_gb: Estimated model size in GB.
        load_function: Optional custom loader with the signature
            ``load_function(model_path, device) -> model``.
        device: Device to load the model on. The default is chosen once, at
            import time.

    Returns:
        Always ``True``.
    """
    self.preload_config[model_id] = dict(
        model_path=model_path,
        preload=preload,
        model_size_gb=model_size_gb,
        load_function=load_function,
        device=device,
    )
    # Registration (re)starts the statistics for this id from zero.
    self.model_stats[model_id] = dict(
        load_count=0,
        use_count=0,
        total_load_time=0,
        last_used=None,
        avg_load_time=0,
    )

    if not preload:
        logger.info(f"模型 {model_id} 已注册")
        return True

    logger.info(f"模型 {model_id} 已注册并标记为预加载")
    # Preload on a daemon thread so registration returns immediately.
    worker = threading.Thread(target=self._preload_model, args=(model_id,), daemon=True)
    worker.start()
    return True
|
||||
|
||||
def get_model(self, model_id):
    """Return the model registered as *model_id*, loading it if necessary.

    Args:
        model_id: Unique identifier of a registered model.

    Returns:
        The loaded model object.

    Raises:
        ValueError: If *model_id* has not been registered.
    """
    if model_id not in self.preload_config:
        raise ValueError(f"模型 {model_id} 未注册")

    # Record the access.
    stats = self.model_stats[model_id]
    stats['last_used'] = datetime.now()
    stats['use_count'] += 1

    # Fast path: already loaded. move_to_end marks the model as most recently
    # used without ever taking it out of the mapping (pop + reinsert left a
    # window in which other threads saw the model as missing).
    try:
        self.loaded_models.move_to_end(model_id)
        model = self.loaded_models[model_id]
    except KeyError:
        # Not loaded, or unloaded by another thread in the meantime.
        pass
    else:
        logger.debug(f"使用已加载的模型: {model_id}")
        return model

    # One lock per model id. setdefault makes sure every thread ends up with
    # the same lock object; a separate check-then-assign could hand out two.
    loading_lock = self.loading_locks.setdefault(model_id, threading.Lock())

    with loading_lock:
        # Another thread may have finished loading while we waited.
        model = self.loaded_models.get(model_id)
        if model is not None:
            return model

        # Make room for the model before loading it.
        self._ensure_memory_available(self.preload_config[model_id]['model_size_gb'])

        start_time = time.time()
        model = self._load_model(model_id)
        load_time = time.time() - start_time

        # Update the load statistics.
        stats['load_count'] += 1
        stats['total_load_time'] += load_time
        stats['avg_load_time'] = stats['total_load_time'] / stats['load_count']

        logger.info(f"模型 {model_id} 加载完成,耗时: {load_time:.2f}秒")

        self.loaded_models[model_id] = model
        return model
|
||||
|
||||
def unload_model(self, model_id):
    """
    Manually unload a model.

    Args:
        model_id: Unique identifier of the model to unload.

    Returns:
        True if the model was loaded and has been removed, False if it was
        not loaded.
    """
    if model_id not in self.loaded_models:
        return False

    logger.info(f"手动卸载模型: {model_id}")
    del self.loaded_models[model_id]

    # Force a collection pass, then clear the CUDA cache when one exists.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return True
|
||||
|
||||
def get_model_stats(self):
    """
    Return usage statistics for every registered model.

    Each entry is a copy of the model's statistics dict extended with
    ``is_loaded`` plus the ``preload``, ``model_size_gb`` and ``device``
    values from its registration config.
    """
    def describe(model_id, stats):
        config = self.preload_config[model_id]
        entry = dict(stats)
        entry['is_loaded'] = model_id in self.loaded_models
        entry['preload'] = config['preload']
        entry['model_size_gb'] = config['model_size_gb']
        entry['device'] = config['device']
        return entry

    return {
        model_id: describe(model_id, stats)
        for model_id, stats in self.model_stats.items()
    }
|
||||
|
||||
def preload_all(self):
    """Start a daemon thread for each preload-flagged model that is not loaded yet."""
    # Generator, not a list: the "already loaded" check is evaluated lazily,
    # one model at a time, as the threads are started.
    pending = (
        model_id
        for model_id, config in self.preload_config.items()
        if config['preload'] and model_id not in self.loaded_models
    )
    for model_id in pending:
        worker = threading.Thread(target=self._preload_model, args=(model_id,), daemon=True)
        worker.start()
|
||||
|
||||
def _preload_model(self, model_id):
    """
    Load one model in the background (thread target used for preloading).

    The load runs under the same per-model lock that ``get_model`` uses, and
    is skipped when the model is already loaded, so a preload thread and a
    concurrent ``get_model`` call cannot both load the same model. Any
    exception is logged rather than raised.

    Args:
        model_id: Unique identifier of the model to preload.
    """
    try:
        logger.info(f"开始预加载模型: {model_id}")

        # Share the lock with get_model; setdefault creates it on first use.
        load_lock = self.loading_locks.setdefault(model_id, threading.Lock())

        with load_lock:
            if model_id in self.loaded_models:
                # Someone else loaded it while this thread was starting up.
                logger.debug(f"模型 {model_id} 已加载,跳过预加载")
                return

            # Make room for the new model before loading it.
            self._ensure_memory_available(self.preload_config[model_id]['model_size_gb'])

            start_time = time.time()
            model = self._load_model(model_id)
            load_time = time.time() - start_time

            # Update load statistics.
            stats = self.model_stats[model_id]
            stats['load_count'] += 1
            stats['total_load_time'] += load_time
            stats['avg_load_time'] = stats['total_load_time'] / stats['load_count']

            self.loaded_models[model_id] = model
            logger.info(f"模型 {model_id} 预加载完成,耗时: {load_time:.2f}秒")

    except Exception as e:
        # Best effort: a failed preload must not take down the thread noisily.
        logger.error(f"预加载模型 {model_id} 失败: {e}")
|
||||
|
||||
def _load_model(self, model_id):
    """
    Load and return the model described by ``self.preload_config[model_id]``.

    Resolution order:
      1. A registered ``load_function`` is called with the configured
         ``model_path`` and ``device``.
      2. ``.pt`` / ``.pth`` files are loaded with ``torch.load``.
      3. ``.pkl`` files are loaded with ``pickle``.
      4. A directory is loaded through ``transformers`` and returned as
         ``{'model': ..., 'tokenizer': ...}``.

    Args:
        model_id: Unique identifier of a registered model.

    Returns:
        Whatever the selected loader returns.

    Raises:
        ValueError: If no loading strategy matches the model path.
        ImportError: If a directory model is requested but ``transformers``
            is not installed.
    """
    config = self.preload_config[model_id]

    custom_loader = config['load_function']
    if custom_loader is not None:
        # Custom loaders receive the path exactly as it was registered.
        return custom_loader(config['model_path'], config['device'])

    device = config['device']
    # Accept os.PathLike objects (e.g. pathlib.Path) as well as plain strings.
    model_path = os.fspath(config['model_path'])

    # Default strategy: pick the loader from the file extension.
    if model_path.endswith(('.pt', '.pth')):
        # NOTE(security): torch.load unpickles the file; only use trusted paths.
        return torch.load(model_path, map_location=device)

    if model_path.endswith('.pkl'):
        # NOTE(security): pickle can execute arbitrary code; only use trusted paths.
        import pickle
        with open(model_path, 'rb') as f:
            return pickle.load(f)

    if os.path.isdir(model_path):
        # A directory is treated as a pretrained transformers checkpoint.
        try:
            from transformers import AutoModel, AutoTokenizer
            model = AutoModel.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            return {'model': model.to(device), 'tokenizer': tokenizer}
        except ImportError:
            logger.error("transformers库未安装,无法加载预训练模型")
            raise
        except Exception as e:
            logger.error(f"加载预训练模型失败: {e}")
            raise

    raise ValueError(f"无法确定如何加载模型: {model_path}")
|
||||
|
||||
def _ensure_memory_available(self, required_gb):
    """
    Evict loaded models until ``required_gb`` more fits in the memory budget.

    Usage is estimated from the registered ``model_size_gb`` values, not
    measured. Models are considered from the front of ``self.loaded_models``.
    A model is protected, and rotated to the back instead of evicted, when it
    is flagged ``preload`` and was used within the last ``self.unload_timeout``
    minutes. If every remaining model is protected, the budget is allowed to
    be exceeded.

    Args:
        required_gb: Size in GB of the model about to be loaded.
    """
    # Nothing loaded means nothing can be evicted.
    if not self.loaded_models:
        return

    current_usage = sum(
        self.preload_config[model_id]['model_size_gb']
        for model_id in self.loaded_models
    )

    # Protected models rotated past since the last eviction. The previous
    # implementation only stopped when a single model was left, so two or
    # more protected models made this loop rotate forever.
    skipped = 0

    while current_usage + required_gb > self.max_memory_usage and self.loaded_models:
        # The first key is the eviction candidate.
        oldest_model_id = next(iter(self.loaded_models))
        last_used = self.model_stats[oldest_model_id]['last_used']

        is_protected = (
            self.preload_config[oldest_model_id]['preload']
            and last_used
            and (datetime.now() - last_used) < timedelta(minutes=self.unload_timeout)
        )

        if is_protected:
            skipped += 1
            # Move the protected model to the back so the next candidate
            # comes to the front.
            model = self.loaded_models.pop(oldest_model_id)
            self.loaded_models[oldest_model_id] = model
            if skipped >= len(self.loaded_models):
                # A full pass found only protected models: exceed the limit.
                break
            continue

        # Evict the candidate and account for the freed memory.
        model_size = self.preload_config[oldest_model_id]['model_size_gb']
        del self.loaded_models[oldest_model_id]
        current_usage -= model_size
        skipped = 0
        logger.info(f"自动卸载模型以释放内存: {oldest_model_id} ({model_size}GB)")

    # Force a collection pass, then clear the CUDA cache when one exists.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
|
||||
|
||||
def _monitor_models(self):
    """
    Background loop that unloads idle models which are not flagged ``preload``.

    Each sweep removes every loaded model whose ``preload`` flag is off and
    whose ``last_used`` timestamp is older than ``self.unload_timeout``
    minutes, then sleeps for five minutes. Errors are logged and the loop
    continues after the same delay. The method never returns.
    """
    while True:
        try:
            now = datetime.now()
            # Iterate over a snapshot because entries are deleted mid-sweep.
            for model_id in list(self.loaded_models.keys()):
                if self.preload_config[model_id]['preload']:
                    continue
                last_used = self.model_stats[model_id]['last_used']
                if not last_used:
                    continue
                if (now - last_used) > timedelta(minutes=self.unload_timeout):
                    logger.info(f"卸载长时间未使用的模型: {model_id}")
                    del self.loaded_models[model_id]
                    # Force a collection pass, then clear the CUDA cache.
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

            # Sweep every five minutes.
            time.sleep(300)
        except Exception as e:
            logger.error(f"模型监控线程出错: {e}")
            time.sleep(300)
|
||||
|
||||
# 创建全局模型管理器实例
|
||||
model_manager = ModelManager()
|
||||
|
||||
# 注册示例函数
|
||||
def register_sentiment_model():
    """
    Register the basic sentiment-analysis model (example registration).

    Returns:
        True if registration succeeded, False if anything failed, including
        a failed import of the project's sentiment loader.
    """
    try:
        # Imported inside the try so a missing loader module is reported as a
        # failed registration instead of escaping as an ImportError.
        from utils.model_loader import load_sentiment_model

        model_path = os.path.join('model', 'sentiment.marshal.3')
        model_manager.register_model(
            model_id='sentiment_basic',
            model_path=model_path,
            preload=True,
            model_size_gb=0.2,
            load_function=load_sentiment_model
        )
        return True
    except Exception as e:
        logger.error(f"注册情感分析模型失败: {e}")
        return False
|
||||
|
||||
def register_bert_model():
    """
    Register the BERT classifier model (example registration).

    Returns:
        True if registration succeeded, False if it raised.
    """
    try:
        registration = {
            'model_id': 'bert_classifier',
            'model_path': os.path.join('model_pro', 'bert_model'),
            'preload': True,
            'model_size_gb': 0.8,
        }
        model_manager.register_model(**registration)
    except Exception as e:
        logger.error(f"注册BERT模型失败: {e}")
        return False
    return True
|
||||
|
||||
# 自动注册常用模型(在导入时执行)
|
||||
# Each registrar gets its own try block so that an error escaping from one
# registration does not prevent the remaining models from being registered.
for _registrar in (register_sentiment_model, register_bert_model):
    try:
        _registrar()
    except Exception as e:
        logger.error(f"自动注册模型失败: {e}")
|
||||
@@ -0,0 +1,494 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
import random
|
||||
import torch
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Any, Tuple, Optional, Union, Callable
|
||||
|
||||
logger = logging.getLogger('model_router')
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class ModelRouter:
    """
    Model router: picks an AI model for a given text and task.

    Features:
    1. Chooses a model from the content's features and the task's needs.
    2. Knows several providers and model families.
    3. Weighs capability, cost and context size when routing.
    4. Records every selection in ``usage_history``.
    5. Lets callers register custom or private models.

    The class is a process-wide singleton: every ``ModelRouter()`` call
    returns the same instance and ``__init__`` runs its setup only once.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ModelRouter, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return

        # Built-in model catalogue. Capability scores are in the 0-1 range.
        self.models = {
            # OpenAI models
            'gpt-4o-latest': {
                'provider': 'openai',
                'capabilities': {
                    'text_analysis': 0.95,
                    'sentiment_analysis': 0.92,
                    'keyword_extraction': 0.90,
                    'summarization': 0.93,
                    'classification': 0.89,
                    'chinese_text': 0.88
                },
                'cost_per_1k': 0.01,
                'max_tokens': 128000,
                'avg_latency': 2.5,  # seconds
                'requires_api_key': 'OPENAI_API_KEY'
            },
            'gpt-4o-mini': {
                'provider': 'openai',
                'capabilities': {
                    'text_analysis': 0.85,
                    'sentiment_analysis': 0.82,
                    'keyword_extraction': 0.80,
                    'summarization': 0.84,
                    'classification': 0.81,
                    'chinese_text': 0.79
                },
                'cost_per_1k': 0.00015,
                'max_tokens': 4000,
                'avg_latency': 1.2,
                'requires_api_key': 'OPENAI_API_KEY'
            },
            'gpt-3.5-turbo': {
                'provider': 'openai',
                'capabilities': {
                    'text_analysis': 0.75,
                    'sentiment_analysis': 0.72,
                    'keyword_extraction': 0.70,
                    'summarization': 0.77,
                    'classification': 0.73,
                    'chinese_text': 0.65
                },
                'cost_per_1k': 0.0015,
                'max_tokens': 16000,
                'avg_latency': 0.8,
                'requires_api_key': 'OPENAI_API_KEY'
            },

            # Claude models
            'claude-3.5-sonnet': {
                'provider': 'anthropic',
                'capabilities': {
                    'text_analysis': 0.90,
                    'sentiment_analysis': 0.91,
                    'keyword_extraction': 0.85,
                    'summarization': 0.92,
                    'classification': 0.89,
                    'chinese_text': 0.80
                },
                'cost_per_1k': 0.015,
                'max_tokens': 200000,
                'avg_latency': 2.8,
                'requires_api_key': 'ANTHROPIC_API_KEY'
            },
            'claude-3.5-haiku': {
                'provider': 'anthropic',
                'capabilities': {
                    'text_analysis': 0.84,
                    'sentiment_analysis': 0.83,
                    'keyword_extraction': 0.79,
                    'summarization': 0.85,
                    'classification': 0.80,
                    'chinese_text': 0.72
                },
                'cost_per_1k': 0.0025,
                'max_tokens': 200000,
                'avg_latency': 1.5,
                'requires_api_key': 'ANTHROPIC_API_KEY'
            },

            # DeepSeek models
            'deepseek-chat': {
                'provider': 'deepseek',
                'capabilities': {
                    'text_analysis': 0.82,
                    'sentiment_analysis': 0.79,
                    'keyword_extraction': 0.77,
                    'summarization': 0.80,
                    'classification': 0.77,
                    'chinese_text': 0.90  # particularly strong on Chinese
                },
                'cost_per_1k': 0.002,
                'max_tokens': 4000,
                'avg_latency': 1.0,
                'requires_api_key': 'DEEPSEEK_API_KEY'
            },
            'deepseek-reasoner': {
                'provider': 'deepseek',
                'capabilities': {
                    'text_analysis': 0.87,
                    'sentiment_analysis': 0.75,
                    'keyword_extraction': 0.76,
                    'summarization': 0.78,
                    'classification': 0.85,
                    'chinese_text': 0.88
                },
                'cost_per_1k': 0.003,
                'max_tokens': 4000,
                'avg_latency': 1.8,
                'requires_api_key': 'DEEPSEEK_API_KEY'
            }
        }

        # Task types and the capabilities each one relies on.
        self.task_types = {
            'sentiment_analysis': {
                'description': '情感分析',
                'key_capabilities': ['sentiment_analysis', 'text_analysis'],
                'example_prompt': '分析以下文本的情感倾向(积极、消极或中性)'
            },
            'topic_classification': {
                'description': '话题分类',
                'key_capabilities': ['classification', 'text_analysis'],
                'example_prompt': '将以下文本分类到最合适的话题类别'
            },
            'keyword_extraction': {
                'description': '关键词提取',
                'key_capabilities': ['keyword_extraction', 'text_analysis'],
                'example_prompt': '从以下文本中提取5个最重要的关键词'
            },
            'text_summarization': {
                'description': '文本摘要',
                'key_capabilities': ['summarization', 'text_analysis'],
                'example_prompt': '为以下文本生成一个简短的摘要'
            },
            'comprehensive_analysis': {
                'description': '综合分析',
                'key_capabilities': ['text_analysis', 'sentiment_analysis', 'keyword_extraction', 'summarization'],
                'example_prompt': '对以下文本进行全面分析,包括情感、关键词和主要观点'
            }
        }

        # Selection history, keyed by task type.
        self.usage_history = defaultdict(list)

        # Cache of the models that are currently usable.
        self.available_models = {}
        self._update_available_models()

        self._initialized = True
        logger.info("模型路由器初始化完成")

    @staticmethod
    def _is_model_available(model_info):
        """
        Return True if the model described by ``model_info`` can be used.

        A model that names an API-key environment variable is available only
        when that variable is set and non-empty. A model that names none
        (e.g. a custom model registered without ``requires_api_key``) needs
        no credentials and is always available.
        """
        api_key_env = model_info.get('requires_api_key')
        if not api_key_env:
            return True
        return bool(os.getenv(api_key_env))

    @staticmethod
    def _mean_capability(model_info):
        """Return the mean of a model's capability scores, or 0.0 if it has none."""
        capabilities = model_info['capabilities']
        if not capabilities:
            return 0.0
        return sum(capabilities.values()) / len(capabilities)

    def _update_available_models(self):
        """Rebuild ``self.available_models`` from ``self.models``."""
        self.available_models = {
            model_id: model_info
            for model_id, model_info in self.models.items()
            if self._is_model_available(model_info)
        }

        if not self.available_models:
            logger.warning("未找到可用的模型,请检查API密钥配置")
        else:
            logger.info(f"找到 {len(self.available_models)} 个可用模型")

    def detect_content_type(self, text: str) -> Dict[str, float]:
        """
        Estimate simple content features of ``text``.

        Args:
            text: Text to analyse.

        Returns:
            Dict with three scores:
            ``chinese_text`` is the share of characters in U+4E00..U+9FFF,
            ``length`` is ``len(text) / 10000`` capped at 1.0, and
            ``complexity`` is a rough score from average sentence length and
            lexical diversity, capped at 1.0. All three are 0.0 for empty
            input.
        """
        features = {
            'chinese_text': 0.0,
            'length': 0.0,
            'complexity': 0.0
        }

        if not text:
            return features

        # Share of Chinese characters.
        chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
        total_chars = len(text)
        chinese_ratio = chinese_chars / total_chars if total_chars > 0 else 0

        # Length score, normalised to 0-1.
        length_score = min(1.0, len(text) / 10000)

        # Rough complexity estimate from sentence length and vocabulary spread.
        sentences = re.split(r'[.!?。!?]', text)
        avg_sentence_len = sum(len(s) for s in sentences) / len(sentences) if sentences else 0

        # Tokenise once; both counts come from the same word list.
        words = re.findall(r'\w+', text.lower())
        unique_words = len(set(words))
        total_words = len(words)

        lexical_diversity = unique_words / total_words if total_words > 0 else 0
        complexity_score = (avg_sentence_len / 50 + lexical_diversity) / 2
        complexity_score = min(1.0, complexity_score)

        features['chinese_text'] = chinese_ratio
        features['length'] = length_score
        features['complexity'] = complexity_score

        return features

    def select_model(self, text: str, task_type: str,
                     optimize_for: str = 'balanced',
                     exclude_models: Optional[List[str]] = None) -> Optional[str]:
        """
        Choose the best-scoring available model for ``text`` and ``task_type``.

        Args:
            text: Text to process.
            task_type: Task type such as 'sentiment_analysis'. Unknown values
                fall back to 'comprehensive_analysis'.
            optimize_for: 'cost', 'performance' or 'balanced'.
            exclude_models: Model IDs to leave out of the scoring.

        Returns:
            The selected model ID, or None when no model is available. If
            every available model is excluded, the first available model is
            returned as a fallback and no usage record is written.
        """
        if not self.available_models:
            logger.error("没有可用的模型,请检查API密钥配置")
            return None

        if task_type not in self.task_types:
            logger.warning(f"未知的任务类型: {task_type},使用默认任务类型: 'comprehensive_analysis'")
            task_type = 'comprehensive_analysis'

        content_features = self.detect_content_type(text)
        task_capabilities = self.task_types[task_type]['key_capabilities']

        model_scores = {}
        exclude_models = exclude_models or []

        for model_id, model_info in self.available_models.items():
            if model_id in exclude_models:
                continue

            # Mean score over the capabilities this task relies on.
            capability_score = 0
            for capability in task_capabilities:
                capability_score += model_info['capabilities'].get(capability, 0)
            capability_score /= len(task_capabilities)

            # Multiplier derived from the content features.
            content_score = 1.0

            # Mostly Chinese text: weigh in the model's Chinese capability.
            if content_features['chinese_text'] > 0.5:
                chinese_capability = model_info['capabilities'].get('chinese_text', 0)
                content_score *= (1.0 + chinese_capability) / 2

            # Long text: penalise models with a small context window.
            if content_features['length'] > 0.7:
                max_tokens = model_info.get('max_tokens', 4000)
                if max_tokens < 10000:
                    content_score *= 0.7

            # Complex text: favour models with a higher capability score.
            if content_features['complexity'] > 0.7:
                content_score *= (1.0 + capability_score) / 2

            final_score = capability_score * content_score

            if optimize_for == 'cost':
                # Cheaper models score higher; cost is normalised against 0.03.
                cost_factor = 1 - min(1.0, model_info.get('cost_per_1k', 0) / 0.03)
                final_score = final_score * 0.3 + cost_factor * 0.7
            elif optimize_for == 'performance':
                # Capability dominates the score.
                final_score = capability_score * 0.8 + content_score * 0.2
            # 'balanced' keeps the unadjusted score.

            model_scores[model_id] = final_score

        if not model_scores:
            logger.warning("没有符合条件的可用模型")
            return list(self.available_models.keys())[0]

        selected_model = max(model_scores, key=model_scores.get)

        # Record the selection.
        self.usage_history[task_type].append({
            'model': selected_model,
            'timestamp': datetime.now().timestamp(),
            'score': model_scores[selected_model],
            'optimize_for': optimize_for
        })

        logger.info(f"为任务 '{task_type}' 选择了模型: {selected_model} (得分: {model_scores[selected_model]:.4f})")
        return selected_model

    def get_model_info(self, model_id: str) -> Optional[Dict]:
        """Return the info dict registered for ``model_id``, or None if unknown."""
        if model_id in self.models:
            return self.models[model_id]
        return None

    def get_available_models(self, refresh: bool = False) -> Dict[str, Dict]:
        """Return the available models, re-checking availability first if ``refresh``."""
        if refresh:
            self._update_available_models()
        return self.available_models

    def get_model_by_provider(self, provider: str, optimize_for: str = 'balanced') -> Optional[str]:
        """
        Return the recommended available model of ``provider``.

        Args:
            provider: Provider name, e.g. 'openai'.
            optimize_for: 'cost' picks the cheapest model, 'performance' the
                highest mean capability, anything else an equal blend of both.

        Returns:
            The model ID, or None if the provider has no available model.
        """
        provider_models = {
            model_id: info for model_id, info in self.available_models.items()
            if info['provider'] == provider
        }

        if not provider_models:
            logger.warning(f"未找到提供商 '{provider}' 的可用模型")
            return None

        if optimize_for == 'cost':
            return min(provider_models.items(), key=lambda x: x[1].get('cost_per_1k', float('inf')))[0]

        if optimize_for == 'performance':
            return max(provider_models.items(), key=lambda x: self._mean_capability(x[1]))[0]

        # Balanced: equal weight on mean capability and normalised cost.
        scores = {}
        for model_id, info in provider_models.items():
            perf_score = self._mean_capability(info)
            cost_score = 1 - min(1.0, info.get('cost_per_1k', 0) / 0.03)
            scores[model_id] = perf_score * 0.5 + cost_score * 0.5

        return max(scores, key=scores.get)

    def get_task_types(self) -> Dict[str, Dict]:
        """Return the supported task types."""
        return self.task_types

    def register_custom_model(self, model_id: str, model_info: Dict[str, Any]) -> bool:
        """
        Register a custom model.

        Args:
            model_id: Unique model identifier.
            model_info: Model description with these fields:
                - provider: provider name (required)
                - capabilities: dict of capability scores (required)
                - cost_per_1k: cost per thousand tokens (required)
                - max_tokens: maximum token limit (required)
                - avg_latency: average latency in seconds
                - requires_api_key: name of the API-key environment
                  variable; leave it out for models that need no key

        Returns:
            True if the model was registered, False if validation failed.
        """
        required_fields = ['provider', 'capabilities', 'cost_per_1k', 'max_tokens']
        for field in required_fields:
            if field not in model_info:
                logger.error(f"注册模型失败: 缺少必要字段 '{field}'")
                return False

        if not isinstance(model_info['capabilities'], dict):
            logger.error("注册模型失败: 'capabilities' 必须是字典")
            return False

        self.models[model_id] = model_info

        # Refresh availability so the new model can be selected right away.
        self._update_available_models()

        logger.info(f"成功注册自定义模型: {model_id}")
        return True
|
||||
|
||||
# 创建全局模型路由器实例
|
||||
model_router = ModelRouter()
|
||||
|
||||
def select_model(text, task_type, optimize_for='balanced', exclude_models=None):
    """Pick the most suitable model by delegating to the global router."""
    chosen = model_router.select_model(
        text,
        task_type,
        optimize_for,
        exclude_models,
    )
    return chosen
|
||||
|
||||
def get_available_models(refresh=False):
    """Return the global router's available models, optionally refreshed first."""
    available = model_router.get_available_models(refresh)
    return available
|
||||
|
||||
def get_model_by_provider(provider, optimize_for='balanced'):
    """Return the global router's recommended model for ``provider``."""
    recommended = model_router.get_model_by_provider(provider, optimize_for)
    return recommended
|
||||
|
||||
def get_task_types():
    """Return the task types supported by the global router."""
    supported = model_router.get_task_types()
    return supported
|
||||
|
||||
def register_custom_model(model_id, model_info):
    """Register a custom model with the global router and return its result."""
    registered = model_router.register_custom_model(model_id, model_info)
    return registered
|
||||
|
||||
# 示例用法
|
||||
# Example usage
if __name__ == "__main__":
    # Sample texts
    chinese_text = """
近日,人工智能技术的发展引发广泛关注。
专家指出,大型语言模型在自然语言处理领域取得了显著进展,
但同时也带来了诸多伦理和安全问题。对此,业界呼吁加强监管,
确保人工智能的发展能够造福人类社会。
"""

    english_text = """
Recent developments in artificial intelligence technology have drawn widespread attention.
Experts point out that large language models have made significant progress in the field of natural language processing,
but also bring many ethical and security issues. In response, the industry calls for strengthened regulation
to ensure that the development of artificial intelligence can benefit human society.
"""

    # (heading, sample text, optimisation goal) for each selection demo
    scenarios = [
        ("中文文本任务测试:", chinese_text, 'balanced'),
        ("\n英文文本任务测试:", english_text, 'balanced'),
        ("\n成本优化测试:", chinese_text, 'cost'),
        ("\n性能优化测试:", chinese_text, 'performance'),
    ]
    for heading, sample, goal in scenarios:
        print(heading)
        chosen = select_model(sample, 'sentiment_analysis', optimize_for=goal)
        print(f"选择的模型: {chosen}")

    # Recommended model per provider
    print("\n根据提供商获取模型:")
    for provider in ['openai', 'anthropic', 'deepseek']:
        model = get_model_by_provider(provider)
        print(f"{provider}: {model}" if model else f"{provider}: 无可用模型")
|
||||
@@ -0,0 +1,357 @@
|
||||
import re
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger('sensitive_filter')
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class SensitiveDataFilter:
    """
    Sensitive-data filter: detects and masks sensitive information in text.

    Features:
    1. Masks phone numbers, e-mail addresses, ID-card numbers, credit-card
       numbers and street addresses using regular expressions.
    2. Supports custom patterns and replacement texts.
    3. Filters plain strings as well as nested dicts and lists.

    The class is a process-wide singleton: every ``SensitiveDataFilter()``
    call returns the same instance and ``__init__`` runs its setup only once.
    """

    _instance = None

    # ASCII-only token boundaries. Python's ``\b`` treats CJK characters as
    # word characters, so a number or address written directly after Chinese
    # text (e.g. "电话13812345678") has no ``\b`` in front of it and was
    # never matched. These lookarounds only refuse ASCII letters, digits and
    # underscore next to the match.
    _TOKEN_START = r'(?<![0-9A-Za-z_])'
    _TOKEN_END = r'(?![0-9A-Za-z_])'

    # Built-in patterns; copied into ``self.config`` by ``__init__``.
    DEFAULT_PATTERNS = {
        'phone': _TOKEN_START + r'1[3-9]\d{9}' + _TOKEN_END,
        # The TLD class is [A-Za-z]; the former [A-Z|a-z] also matched "|".
        'email': _TOKEN_START + r'[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}' + _TOKEN_END,
        'id_card': (_TOKEN_START
                    + r'[1-9]\d{5}(19|20)\d{2}(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])\d{3}[\dXx]'
                    + _TOKEN_END),
        'credit_card': _TOKEN_START + r'\d{4}[ -]?\d{4}[ -]?\d{4}[ -]?\d{4}' + _TOKEN_END,
        'address': r'(北京|上海|广州|深圳|天津|重庆|南京|杭州|武汉|成都|西安)市.*?(路|街|道|巷).*?(号)'
    }

    # Built-in replacement texts, keyed like DEFAULT_PATTERNS.
    DEFAULT_REPLACEMENTS = {
        'phone': '***********',
        'email': '******@*****',
        'id_card': '******************',
        'credit_card': '****************',
        'address': '[地址已隐藏]'
    }

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(SensitiveDataFilter, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return

        # Default configuration; ENABLE_SENSITIVE_DATA_FILTER turns the filter
        # off when set to anything other than "true" (case-insensitive).
        self.config = {
            'enabled': os.getenv('ENABLE_SENSITIVE_DATA_FILTER', 'true').lower() == 'true',
            'patterns': dict(self.DEFAULT_PATTERNS),
            'replacements': dict(self.DEFAULT_REPLACEMENTS)
        }

        # Overlay the custom configuration file, if present.
        self._load_config()

        # Compile the regular expressions.
        self._compile_patterns()

        self._initialized = True

        logger.info("敏感数据过滤器初始化完成")
        if self.config['enabled']:
            logger.info(f"已启用以下类型的敏感数据过滤: {', '.join(self.config['patterns'].keys())}")
        else:
            logger.info("敏感数据过滤器已禁用")

    def _load_config(self):
        """
        Merge ``<DATA_DIR>/security/sensitive_filter.json`` into ``self.config``.

        ``DATA_DIR`` defaults to ``data``. A missing file is ignored; a file
        that cannot be read or parsed is logged and otherwise ignored.
        """
        data_dir = os.getenv('DATA_DIR', 'data')
        config_path = os.path.join(data_dir, 'security', 'sensitive_filter.json')

        if os.path.exists(config_path):
            try:
                with open(config_path, 'r', encoding='utf-8') as f:
                    custom_config = json.load(f)

                if 'enabled' in custom_config:
                    self.config['enabled'] = custom_config['enabled']

                if 'patterns' in custom_config:
                    for key, pattern in custom_config['patterns'].items():
                        self.config['patterns'][key] = pattern

                if 'replacements' in custom_config:
                    for key, replacement in custom_config['replacements'].items():
                        self.config['replacements'][key] = replacement

                logger.info(f"已加载自定义敏感数据过滤配置: {config_path}")
            except Exception as e:
                # Best effort: a broken config file must not disable filtering.
                logger.error(f"加载敏感数据过滤配置失败: {e}")

    def _compile_patterns(self):
        """Compile every configured pattern; invalid ones are logged and skipped."""
        self.compiled_patterns = {}
        for key, pattern in self.config['patterns'].items():
            try:
                self.compiled_patterns[key] = re.compile(pattern)
                logger.debug(f"已编译敏感数据模式: {key} - {pattern}")
            except re.error as e:
                logger.error(f"编译敏感数据模式失败: {key} - {pattern}: {e}")

    def filter_text(self, text):
        """
        Mask sensitive information in a string.

        Args:
            text: Text to filter.

        Returns:
            The filtered text. ``text`` is returned unchanged when the filter
            is disabled or ``text`` is empty.
        """
        if not self.config['enabled'] or not text:
            return text

        filtered_text = text
        for key, pattern in self.compiled_patterns.items():
            replacement = self.config['replacements'].get(key, '[FILTERED]')
            filtered_text = pattern.sub(replacement, filtered_text)

        return filtered_text

    def filter_dict(self, data, *skip_keys):
        """
        Mask sensitive information in a dict, recursing into nested values.

        Args:
            data: Dict to filter. A plain string is filtered as text; any
                other non-dict value is returned unchanged.
            skip_keys: Keys whose values are copied over without filtering.

        Returns:
            A new filtered dict. ``data`` itself is returned when the filter
            is disabled or ``data`` is empty.
        """
        if not self.config['enabled'] or not data:
            return data

        if not isinstance(data, dict):
            if isinstance(data, str):
                return self.filter_text(data)
            return data

        filtered_data = {}
        for key, value in data.items():
            if key in skip_keys:
                filtered_data[key] = value
                continue

            if isinstance(value, dict):
                filtered_data[key] = self.filter_dict(value, *skip_keys)
            elif isinstance(value, list):
                # Delegate to filter_list so lists nested inside lists are
                # filtered too; they used to pass through untouched.
                filtered_data[key] = self.filter_list(value, *skip_keys)
            elif isinstance(value, str):
                filtered_data[key] = self.filter_text(value)
            else:
                filtered_data[key] = value

        return filtered_data

    def filter_list(self, data, *skip_keys):
        """
        Mask sensitive information in a list, recursing into nested values.

        Args:
            data: List to filter. A dict or string is handed to the matching
                filter; any other non-list value is returned unchanged.
            skip_keys: Keys to skip inside dict items.

        Returns:
            A new filtered list. ``data`` itself is returned when the filter
            is disabled or ``data`` is empty.
        """
        if not self.config['enabled'] or not data:
            return data

        if not isinstance(data, list):
            if isinstance(data, dict):
                return self.filter_dict(data, *skip_keys)
            if isinstance(data, str):
                return self.filter_text(data)
            return data

        return [
            self.filter_dict(item, *skip_keys) if isinstance(item, dict) else
            self.filter_list(item, *skip_keys) if isinstance(item, list) else
            self.filter_text(item) if isinstance(item, str) else item
            for item in data
        ]

    def is_sensitive_info(self, text, info_type=None):
        """
        Check whether ``text`` contains sensitive information.

        Args:
            text: Text to check.
            info_type: Pattern key to check; None checks every pattern.

        Returns:
            True if a match is found. False otherwise, including when the
            filter is disabled, ``text`` is empty or ``info_type`` is unknown.
        """
        if not self.config['enabled'] or not text:
            return False

        if info_type:
            if info_type not in self.compiled_patterns:
                logger.warning(f"未知的敏感信息类型: {info_type}")
                return False
            return bool(self.compiled_patterns[info_type].search(text))

        return any(pattern.search(text) for pattern in self.compiled_patterns.values())

    def get_sensitive_info_types(self, text):
        """
        List the pattern keys that match ``text``.

        Args:
            text: Text to check.

        Returns:
            Matching pattern keys in configuration order; an empty list when
            the filter is disabled or ``text`` is empty.
        """
        if not self.config['enabled'] or not text:
            return []

        return [
            key for key, pattern in self.compiled_patterns.items()
            if pattern.search(text)
        ]

    def enable(self):
        """Enable the filter."""
        self.config['enabled'] = True
        logger.info("敏感数据过滤器已启用")

    def disable(self):
        """Disable the filter."""
        self.config['enabled'] = False
        logger.info("敏感数据过滤器已禁用")

    def is_enabled(self):
        """Return whether the filter is enabled."""
        return self.config['enabled']

    def add_pattern(self, key, pattern, replacement='[FILTERED]'):
        """
        Add or replace a custom pattern.

        Args:
            key: Identifier of the sensitive-information type.
            pattern: Regular expression string.
            replacement: Text substituted for each match.

        Returns:
            True if the pattern was added, False if it is not a valid regex.
        """
        try:
            # Validate before touching the configuration.
            re.compile(pattern)

            self.config['patterns'][key] = pattern
            self.config['replacements'][key] = replacement

            self._compile_patterns()

            logger.info(f"已添加敏感信息模式: {key}")
            return True
        except re.error as e:
            logger.error(f"添加敏感信息模式失败: {key} - {pattern}: {e}")
            return False

    def remove_pattern(self, key):
        """
        Remove a pattern together with its replacement text.

        Args:
            key: Identifier of the sensitive-information type.

        Returns:
            True if the pattern existed and was removed, False otherwise.
        """
        if key in self.config['patterns']:
            del self.config['patterns'][key]

            if key in self.config['replacements']:
                del self.config['replacements'][key]

            if key in self.compiled_patterns:
                del self.compiled_patterns[key]

            logger.info(f"已移除敏感信息模式: {key}")
            return True

        logger.warning(f"未找到敏感信息模式: {key}")
        return False

    def save_config(self):
        """
        Write ``self.config`` to ``<DATA_DIR>/security/sensitive_filter.json``.

        Returns:
            True if the file was written, False if writing failed.
        """
        data_dir = os.getenv('DATA_DIR', 'data')
        security_dir = os.path.join(data_dir, 'security')
        os.makedirs(security_dir, exist_ok=True)

        config_path = os.path.join(security_dir, 'sensitive_filter.json')

        try:
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(self.config, f, ensure_ascii=False, indent=2)

            logger.info(f"敏感数据过滤配置已保存到: {config_path}")
            return True
        except Exception as e:
            logger.error(f"保存敏感数据过滤配置失败: {e}")
            return False
|
||||
|
||||
# 创建全局敏感数据过滤器实例
|
||||
sensitive_filter = SensitiveDataFilter()
|
||||
|
||||
# 提供便捷的过滤函数
|
||||
def filter_text(text):
    """Mask sensitive information in ``text`` using the global filter."""
    masked = sensitive_filter.filter_text(text)
    return masked
|
||||
|
||||
def filter_dict(data, *skip_keys):
    """Mask sensitive information in a dict using the global filter."""
    masked = sensitive_filter.filter_dict(data, *skip_keys)
    return masked
|
||||
|
||||
def filter_list(data, *skip_keys):
    """Mask sensitive information in a list using the global filter."""
    masked = sensitive_filter.filter_list(data, *skip_keys)
    return masked
|
||||
|
||||
def is_sensitive_info(text, info_type=None):
    """Check with the global filter whether ``text`` contains sensitive information."""
    found = sensitive_filter.is_sensitive_info(text, info_type)
    return found
|
||||
|
||||
# 示例用法
|
||||
# Example usage
if __name__ == "__main__":
    # Sample text containing one value of each built-in type
    test_text = """
联系人: 张三
电话: 13812345678
邮箱: zhangsan@example.com
身份证: 110101199001011234
地址: 北京市海淀区中关村大街20号
信用卡: 6225 1234 5678 9012
"""

    # Mask the sensitive values and show the text before and after
    filtered_text = filter_text(test_text)
    print("原始文本:", test_text, "\n过滤后:", filtered_text, sep="\n")

    # Report which sensitive-information types were found
    found_types = sensitive_filter.get_sensitive_info_types(test_text)
    print(f"\n包含的敏感信息类型: {found_types}")
|
||||
@@ -0,0 +1,837 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from utils.db_manager import DatabaseManager
|
||||
from utils.cache_manager import CacheManager
|
||||
from utils.model_router import ModelRouter
|
||||
from utils.sensitive_filter import SensitiveDataFilter
|
||||
from spider.weibo_crawler import WeiboCrawler
|
||||
from utils.ai_analyzer import AIAnalyzer
|
||||
|
||||
# 配置日志
|
||||
from utils.logger import setup_logger
|
||||
logger = setup_logger('workflow_engine', 'logs/workflow_engine.log')
|
||||
|
||||
class WorkflowEngine:
    """Workflow engine that executes data-crawling and analysis workflows."""

    # Shared instance handed out by __new__ (created on first construction).
    _instance = None
    # Becomes True once __init__ has run, so later constructions skip setup.
    _initialized = False
|
||||
|
||||
def __new__(cls):
    """Return the single shared engine instance, creating it on first use."""
    shared = cls._instance
    if shared is None:
        shared = super(WorkflowEngine, cls).__new__(cls)
        cls._instance = shared
    return cls._instance
|
||||
|
||||
def __init__(self):
    """Create the engine's collaborators once; repeated calls do nothing."""
    if not self._initialized:
        self.db = DatabaseManager()
        self.cache = CacheManager(memory_capacity=50, cache_duration=3600)
        self.model_router = ModelRouter()
        self.sensitive_filter = SensitiveDataFilter()
        self.executor = ThreadPoolExecutor(max_workers=5)
        # task_id -> Future for workflows submitted to the executor
        self.running_tasks = {}

        # Make sure the working directory for workflow data exists.
        workflow_dir = Path('data/workflow')
        workflow_dir.mkdir(parents=True, exist_ok=True)
        self.data_dir = workflow_dir

        self._initialized = True
        logger.info("工作流引擎初始化完成")
|
||||
|
||||
def execute_crawler_workflow(self, task_id, config):
    """Run a crawler workflow and record its outcome on the task row.

    Args:
        task_id: Identifier of the workflow task being executed.
        config: Crawler configuration dict. Recognised keys are ``source``,
            ``crawl_depth`` (or the camelCase ``crawlDepth``), ``interval``
            and ``filters``.

    Returns:
        Whatever ``WeiboCrawler.crawl`` returns, or None when any step
        raised (the task is then marked ``failed``).
    """
    logger.info(f"开始执行爬虫工作流: {task_id}")

    try:
        # Mark the task as running before any work starts.
        self._update_task_status(task_id, 'running', 0)

        crawler = WeiboCrawler()

        source = config.get('source', 'hot_topics')
        # The default crawler templates spell this key "crawlDepth"; accept
        # it as a fallback so their configured depth is not silently
        # replaced by the default of 1.
        depth = config.get('crawl_depth', config.get('crawlDepth', 1))
        interval = config.get('interval', 5)
        filters = config.get('filters', {})

        result = crawler.crawl(
            source=source,
            depth=depth,
            interval=interval,
            filters=filters,
            callback=lambda progress: self._update_task_progress(task_id, progress)
        )

        self._update_task_status(task_id, 'completed', 100, result=result)
        logger.info(f"爬虫工作流完成: {task_id}")

        return result

    except Exception as e:
        logger.error(f"爬虫工作流出错: {str(e)}")
        logger.error(traceback.format_exc())
        self._update_task_status(task_id, 'failed', 0, error=str(e))
        return None
|
||||
|
||||
def execute_analysis_workflow(self, task_id, workflow):
    """Run an analysis workflow component by component.

    Args:
        task_id: Identifier of the workflow task being executed.
        workflow: Workflow definition with ``components`` and
            ``connections`` lists.

    Returns:
        The outputs collected by ``_get_final_outputs`` (passed through the
        sensitive-data filter when it is enabled), or None when any step
        raised; in that case the task is marked ``failed``.
    """
    logger.info(f"开始执行分析工作流: {task_id}")

    try:
        self._update_task_status(task_id, 'running', 0)

        components = workflow.get('components', [])
        connections = workflow.get('connections', [])

        # Both parts are required for the workflow to be considered valid.
        if not (components and connections):
            raise ValueError("工作流配置不完整,缺少组件或连接")

        component_map, dependency_graph = self._build_dependency_graph(components, connections)
        execution_order = self._topological_sort(dependency_graph)

        results_by_id = {}
        step_count = len(execution_order)

        for position, component_id in enumerate(execution_order):
            component = component_map.get(component_id)
            if component:
                # Progress reflects how many components were started so far.
                self._update_task_progress(task_id, int((position / step_count) * 100))

                inputs = self._get_component_input_data(component_id, connections, results_by_id)
                results_by_id[component_id] = self._execute_component(component, inputs)

        final_outputs = self._get_final_outputs(dependency_graph, results_by_id)

        # Scrub sensitive data from the outputs before they are stored.
        if final_outputs and self.sensitive_filter.is_enabled():
            if isinstance(final_outputs, dict):
                final_outputs = self.sensitive_filter.filter_dict(final_outputs)
            elif isinstance(final_outputs, list):
                final_outputs = self.sensitive_filter.filter_list(final_outputs)

        self._update_task_status(task_id, 'completed', 100, result=final_outputs)
        logger.info(f"分析工作流完成: {task_id}")

        return final_outputs

    except Exception as e:
        logger.error(f"分析工作流出错: {str(e)}")
        logger.error(traceback.format_exc())
        self._update_task_status(task_id, 'failed', 0, error=str(e))
        return None
|
||||
|
||||
def start_workflow(self, workflow_type, config, template_id=None):
    """Register a workflow task and start it asynchronously.

    Args:
        workflow_type: Kind of workflow, ``'crawler'`` or ``'analysis'``.
        config: Workflow configuration; stored as JSON on the task row and
            passed to the runner.
        template_id: Optional id of the template the task was created from.

    Returns:
        The generated task id, or None when the type is unknown or the task
        could not be stored/submitted.
    """
    # Resolve the runner BEFORE touching the database so that an unknown
    # type never leaves an orphaned 'pending' row behind.
    runner_names = {
        'crawler': 'execute_crawler_workflow',
        'analysis': 'execute_analysis_workflow',
    }
    runner_name = runner_names.get(workflow_type) if isinstance(workflow_type, str) else None
    if runner_name is None:
        logger.error(f"未知的工作流类型: {workflow_type}")
        return None
    runner = getattr(self, runner_name)

    task_id = str(uuid.uuid4())

    conn = self.db.get_connection()
    cursor = conn.cursor()

    try:
        now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        cursor.execute(
            """
            INSERT INTO workflow_tasks
            (id, template_id, type, status, progress, config, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
            """,
            (
                task_id,
                template_id,
                workflow_type,
                'pending',
                0,
                json.dumps(config, ensure_ascii=False),
                now,
                now
            )
        )
        conn.commit()

        # Hand the workflow to the thread pool and remember its Future.
        self.running_tasks[task_id] = self.executor.submit(runner, task_id, config)

        return task_id

    except Exception as e:
        logger.error(f"启动工作流失败: {str(e)}")
        conn.rollback()
        return None
    finally:
        cursor.close()
|
||||
|
||||
def get_task_status(self, task_id):
    """Look up a workflow task, preferring the cache over the database.

    Args:
        task_id: Identifier of the task.

    Returns:
        The task row with its ``config`` and ``result`` columns decoded from
        JSON, the falsy fetch result when no row matches, or None when the
        lookup raised.
    """
    cache_key = f"task_status:{task_id}"
    hit = self.cache.get(cache_key)
    if hit:
        return hit

    conn = self.db.get_connection()
    cursor = conn.cursor()

    try:
        cursor.execute(
            "SELECT * FROM workflow_tasks WHERE id = %s",
            (task_id,)
        )
        task = cursor.fetchone()

        if not task:
            return task

        # Decode the JSON text columns into Python objects.
        # NOTE(review): assumes rows are dict-like (supports .get) — confirm
        # the cursor type configured by DatabaseManager.
        for column in ('config', 'result'):
            if task.get(column):
                task[column] = json.loads(task[column])

        self.cache.set(cache_key, task)
        return task

    except Exception as e:
        logger.error(f"获取任务状态失败: {str(e)}")
        return None
    finally:
        cursor.close()
|
||||
|
||||
def cancel_task(self, task_id):
    """Cancel a task and mark it as cancelled in the database.

    Args:
        task_id: Identifier of the task.

    Returns:
        bool: True when the status update was committed, False when it
        raised.
    """
    # Drop the task from the running map; ask its Future to cancel if it
    # has not finished yet.
    future = self.running_tasks.pop(task_id, None)
    if future is not None and not future.done():
        future.cancel()

    conn = self.db.get_connection()
    cursor = conn.cursor()

    try:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        cursor.execute(
            """
            UPDATE workflow_tasks
            SET status = %s, updated_at = %s
            WHERE id = %s
            """,
            ('cancelled', timestamp, task_id)
        )
        conn.commit()

        # The cached copy of the task is now stale.
        self.cache.delete(f"task_status:{task_id}")

        return True

    except Exception as e:
        logger.error(f"取消任务失败: {str(e)}")
        conn.rollback()
        return False
    finally:
        cursor.close()
|
||||
|
||||
def _update_task_status(self, task_id, status, progress, result=None, error=None):
    """Write a task's status/progress (and optional result or error) to the DB.

    Args:
        task_id: Identifier of the task row to update.
        status: New status string.
        progress: New progress value.
        result: Optional result object; stored as JSON when not None.
        error: Optional error text; stored when not None.

    Failures are logged and rolled back; nothing is raised to the caller.
    """
    conn = self.db.get_connection()
    cursor = conn.cursor()

    try:
        now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        update_fields = ["status = %s", "progress = %s", "updated_at = %s"]
        params = [status, progress, now]

        # First 'running' update with zero progress stamps the start time.
        if status == 'running' and progress == 0:
            update_fields.append("started_at = %s")
            params.append(now)

        # Terminal states stamp the completion time.
        if status in ['completed', 'failed']:
            update_fields.append("completed_at = %s")
            params.append(now)

        if result is not None:
            update_fields.append("result = %s")
            # default=str is deliberate: results may carry values such as
            # datetime or Decimal that json cannot encode natively. Without
            # it the TypeError is swallowed below and the task never leaves
            # the 'running' state.
            params.append(json.dumps(result, ensure_ascii=False, default=str))

        if error is not None:
            update_fields.append("error = %s")
            params.append(error)

        sql = f"""
            UPDATE workflow_tasks
            SET {', '.join(update_fields)}
            WHERE id = %s
        """
        params.append(task_id)

        cursor.execute(sql, tuple(params))
        conn.commit()

        # Invalidate the cached copy of this task.
        cache_key = f"task_status:{task_id}"
        self.cache.delete(cache_key)

    except Exception as e:
        logger.error(f"更新任务状态失败: {str(e)}")
        conn.rollback()
    finally:
        cursor.close()
|
||||
|
||||
def _update_task_progress(self, task_id, progress):
    """Record a new progress value for a task, keeping its status 'running'."""
    self._update_task_status(task_id, status='running', progress=progress)
|
||||
|
||||
def _build_dependency_graph(self, components, connections):
    """Index components by id and map each id to the ids it depends on.

    Args:
        components: Component dicts, each with an ``id`` key.
        connections: Dicts with ``source`` and ``target`` component ids.

    Returns:
        tuple: ``(component_map, dependency_graph)`` where
        ``dependency_graph[target]`` lists the sources feeding ``target``.
        Connections that reference unknown or empty ids are ignored.
    """
    component_map = {}
    dependency_graph = {}
    for comp in components:
        component_map[comp['id']] = comp
        dependency_graph[comp['id']] = []

    for link in connections:
        source, target = link.get('source'), link.get('target')
        if not (source and target):
            continue
        if source in component_map and target in component_map:
            dependency_graph[target].append(source)

    return component_map, dependency_graph
|
||||
|
||||
def _topological_sort(self, graph):
    """Return the component ids in an order that is safe to execute.

    Args:
        graph: Mapping of component id -> list of ids it depends on.

    Returns:
        list: Ids ordered so that every dependency appears before the
        components that consume its output.

    Raises:
        ValueError: If the graph contains a dependency cycle.
    """
    finished = set()
    in_progress = set()
    order = []

    def visit(node):
        if node in in_progress:
            raise ValueError(f"工作流存在循环依赖: {node}")
        if node in finished:
            return

        in_progress.add(node)
        # Visit everything this node depends on first ...
        for dependency in graph.get(node, []):
            visit(dependency)

        in_progress.remove(node)
        finished.add(node)
        # ... so that the node is emitted only after all of its inputs.
        order.append(node)

    for node in graph:
        if node not in finished:
            visit(node)

    # The post-order above is already dependency-first because the edges
    # point from a component to its inputs; it must NOT be reversed, or
    # consumers would run before their data sources.
    return order
|
||||
|
||||
def _get_component_input_data(self, component_id, connections, result_map):
    """Collect the already-computed results that feed one component.

    Args:
        component_id: Id of the component about to run.
        connections: Connection dicts (``source``, ``target`` and an
            optional ``targetInput`` name).
        result_map: Results of the components executed so far, keyed by id.

    Returns:
        dict: Input name (``'default'`` when the connection names none)
        mapped to the result of the connected source component.
    """
    incoming = (link for link in connections if link.get('target') == component_id)

    input_data = {}
    for link in incoming:
        source_id = link.get('source')
        if source_id in result_map:
            input_data[link.get('targetInput', 'default')] = result_map[source_id]

    return input_data
|
||||
|
||||
def _execute_component(self, component, input_data):
    """Dispatch one component to the handler for its ``type``.

    Args:
        component: Component dict with ``type`` and an optional ``config``.
        input_data: Inputs gathered for this component.

    Returns:
        The handler's result, or None (after a warning) for an unknown type.
    """
    component_type = component.get('type')
    config = component.get('config', {})

    # (component type, name of the handling method) pairs, checked in order.
    handlers = (
        ('data_source', '_execute_data_source'),
        ('preprocessing', '_execute_preprocessing'),
        ('model', '_execute_model'),
        ('visualization', '_execute_visualization'),
    )
    for type_name, handler_name in handlers:
        if component_type == type_name:
            return getattr(self, handler_name)(config, input_data)

    logger.warning(f"未知的组件类型: {component_type}")
    return None
|
||||
|
||||
def _execute_data_source(self, config, input_data):
    """Load the data for a data-source component.

    Args:
        config: Component configuration. ``source_type`` selects the
            backend: ``'database'`` (``table``, ``filters``, ``limit``),
            ``'file'`` (``file_path``) or ``'api'`` (not implemented).
        input_data: Unused; data sources have no upstream input.

    Returns:
        The fetched rows for a database source; the parsed JSON (for
        ``.json`` paths) or raw text for a file source; an empty list for
        unsupported sources, invalid configuration, or on error.
    """
    # Local import: this module's header does not import ``re``.
    import re

    source_type = config.get('source_type')

    if source_type == 'database':
        table = config.get('table')
        filters = config.get('filters', {})
        limit = config.get('limit', 1000)

        # Table and column names cannot be bound as query parameters, and
        # they come from the workflow configuration, so only plain (optionally
        # schema-qualified) identifiers are allowed into the SQL text.
        identifier = re.compile(r'[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)?')

        if not isinstance(table, str) or not identifier.fullmatch(table):
            logger.error(f"非法的数据表名: {table}")
            return []

        # The limit is interpolated as well, so force it to be an integer.
        try:
            limit = int(limit)
        except (TypeError, ValueError):
            logger.error(f"非法的查询数量限制: {limit}")
            return []
        if limit < 0:
            logger.error(f"非法的查询数量限制: {limit}")
            return []

        query_conditions = []
        query_params = []

        for key, value in filters.items():
            if value:
                if not isinstance(key, str) or not identifier.fullmatch(key):
                    logger.error(f"非法的过滤字段名: {key}")
                    return []
                # Values are always passed as bound parameters.
                query_conditions.append(f"{key} = %s")
                query_params.append(value)

        where_clause = f"WHERE {' AND '.join(query_conditions)}" if query_conditions else ""

        sql = f"SELECT * FROM {table} {where_clause} LIMIT {limit}"

        conn = self.db.get_connection()
        cursor = conn.cursor()

        try:
            cursor.execute(sql, tuple(query_params))
            return cursor.fetchall()
        except Exception as e:
            logger.error(f"数据库查询出错: {str(e)}")
            return []
        finally:
            cursor.close()

    elif source_type == 'file':
        # NOTE(review): file_path is taken from the workflow configuration
        # and opened as-is; confirm callers restrict it to a safe directory.
        file_path = config.get('file_path')
        if not file_path or not os.path.exists(file_path):
            return []

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                if file_path.endswith('.json'):
                    return json.load(f)
                else:
                    return f.read()
        except Exception as e:
            logger.error(f"文件读取出错: {str(e)}")
            return []

    elif source_type == 'api':
        # API sources are not implemented yet.
        logger.warning("API数据源暂未实现")
        return []

    else:
        logger.warning(f"未知的数据源类型: {source_type}")
        return []
|
||||
|
||||
def _execute_preprocessing(self, config, input_data):
    """Apply a filter, sort or aggregate step to the ``default`` input.

    Args:
        config: Step configuration; ``preprocessing_type`` selects
            ``'filter'`` (``field``, ``value``, ``operator``), ``'sort'``
            (``field``, ``order``) or ``'aggregate'`` (``group_by``,
            ``aggregate_field``, ``aggregate_type``).
        input_data: Component inputs; only ``input_data['default']`` is read.

    Returns:
        The processed list; an empty list when there is no input; the
        input unchanged when the step is unknown or lacks its key field.
    """
    preprocessing_type = config.get('preprocessing_type')
    data = input_data.get('default', [])

    if not data:
        return []

    if preprocessing_type == 'filter':
        field = config.get('field')
        value = config.get('value')
        operator = config.get('operator', 'eq')

        if not field:
            return data

        def keeps(item):
            # An unrecognised operator keeps nothing.
            if operator == 'eq':
                return item.get(field) == value
            if operator == 'neq':
                return item.get(field) != value
            if operator == 'contains':
                return value in str(item.get(field, ''))
            if operator == 'not_contains':
                return value not in str(item.get(field, ''))
            return False

        return [item for item in data if keeps(item)]

    if preprocessing_type == 'sort':
        field = config.get('field')
        order = config.get('order', 'asc')

        if not field:
            return data

        descending = (order == 'desc')
        return sorted(data, key=lambda row: row.get(field, ''), reverse=descending)

    if preprocessing_type == 'aggregate':
        group_by = config.get('group_by')
        aggregate_field = config.get('aggregate_field')
        aggregate_type = config.get('aggregate_type', 'count')

        if not group_by:
            return data

        # Per-group running statistics, in first-seen group order.
        buckets = {}
        for item in data:
            stats = buckets.setdefault(
                item.get(group_by),
                {'count': 0, 'sum': 0, 'values': []}
            )
            stats['count'] += 1

            if aggregate_field:
                value = item.get(aggregate_field, 0)
                # Only numeric values contribute to the sum.
                if isinstance(value, (int, float)):
                    stats['sum'] += value
                    stats['values'].append(value)

        def summarize(key, stats):
            row = {group_by: key}
            if aggregate_type == 'count':
                row['value'] = stats['count']
            elif aggregate_type == 'sum':
                row['value'] = stats['sum']
            elif aggregate_type == 'avg':
                row['value'] = stats['sum'] / stats['count'] if stats['count'] > 0 else 0
            return row

        return [summarize(key, stats) for key, stats in buckets.items()]

    logger.warning(f"未知的预处理类型: {preprocessing_type}")
    return data
|
||||
|
||||
def _execute_model(self, config, input_data):
    """Run an AI analysis step over the ``default`` input.

    Args:
        config: Step configuration. ``model_type`` selects the analysis:
            ``'sentiment'``, ``'topic'``, ``'keywords'`` or ``'summarize'``.
            The names used by the bundled templates and component catalogue
            (``'sentiment_analysis'``, ``'topic_classification'``,
            ``'keyword_extraction'``) are accepted as aliases.
            ``text_field`` (default ``'content'``) names the text key when
            the input is a list of records.
        input_data: Component inputs; only ``input_data['default']`` is read.

    Returns:
        For list input: the same list, with the analysis result stored on
        every record that has text (under ``sentiment`` / ``topic`` /
        ``keywords`` / ``summary``). For string input: the single result.
        An empty list when there is no input; None for other input types;
        the input unchanged (after a warning) for an unknown model type.
    """
    model_type = config.get('model_type')
    data = input_data.get('default', [])

    if not data:
        return []

    # Alternative spellings that appear in the default templates/catalogue.
    aliases = {
        'sentiment_analysis': 'sentiment',
        'topic_classification': 'topic',
        'keyword_extraction': 'keywords',
    }
    # canonical type -> (router task name, AIAnalyzer method, result key)
    specs = {
        'sentiment': ('sentiment', 'analyze_sentiment', 'sentiment'),
        'topic': ('topic', 'analyze_topic', 'topic'),
        'keywords': ('keyword', 'extract_keywords', 'keywords'),
        'summarize': ('summarization', 'summarize_text', 'summary'),
    }

    spec = None
    if isinstance(model_type, str):
        spec = specs.get(aliases.get(model_type, model_type))
    if spec is None:
        logger.warning(f"未知的模型类型: {model_type}")
        return data

    router_task, analyzer_method, result_key = spec

    # Records that actually carry text, and the texts themselves, kept in
    # lockstep so each result can be written back to the right record.
    targets = []
    texts = []
    if isinstance(data, list):
        field = config.get('text_field', 'content')
        targets = [item for item in data if item.get(field)]
        texts = [item.get(field, '') for item in targets]
    elif isinstance(data, str):
        texts = [data]

    # The model is chosen once, based on the first text.
    model = self.model_router.select_model_for_text(texts[0] if texts else "", router_task)

    analyze = getattr(AIAnalyzer(), analyzer_method)
    results = [analyze(text, model=model) for text in texts]

    if isinstance(data, list):
        # Pair each result with the record it was computed from; indexing
        # results by the record's position in ``data`` would misalign as
        # soon as one record has no text.
        for item, result in zip(targets, results):
            item[result_key] = result
        return data

    return results[0] if results else None
|
||||
|
||||
def _execute_visualization(self, config, input_data):
    """Build the display payload for a visualization component.

    Args:
        config: Step configuration; ``visualization_type`` selects
            ``'chart'`` (``chart_type``, ``x_field``, ``y_field``,
            ``title``), ``'table'`` (``columns``, ``title``) or
            ``'wordcloud'`` (``word_field``, ``value_field``, ``title``).
        input_data: Component inputs; only ``input_data['default']`` is read.

    Returns:
        dict: The payload for the chosen visualization, ``{'error': ...}``
        when a required field name is missing, or an empty dict when there
        is no input or the type is unknown.
    """
    visualization_type = config.get('visualization_type')
    data = input_data.get('default', [])

    if not data:
        return {}

    if visualization_type == 'chart':
        chart_type = config.get('chart_type', 'bar')
        x_field = config.get('x_field')
        y_field = config.get('y_field')
        title = config.get('title', '数据可视化')

        if not (x_field and y_field):
            return {'error': '缺少x或y字段'}

        # Keep only records where both coordinates are present.
        pairs = [(item.get(x_field), item.get(y_field)) for item in data]
        points = [(x, y) for x, y in pairs if x is not None and y is not None]

        return {
            'type': chart_type,
            'title': title,
            'xAxis': {'type': 'category', 'data': [x for x, _ in points]},
            'yAxis': {'type': 'value'},
            'series': [{'data': [y for _, y in points]}]
        }

    if visualization_type == 'table':
        columns = config.get('columns', [])
        title = config.get('title', '数据表格')

        # Without explicit columns, fall back to the first record's keys.
        if not columns and isinstance(data, list) and data:
            columns = list(data[0].keys())

        return {
            'type': 'table',
            'title': title,
            'columns': columns,
            'data': data
        }

    if visualization_type == 'wordcloud':
        word_field = config.get('word_field')
        value_field = config.get('value_field')
        title = config.get('title', '词云图')

        if not word_field:
            return {'error': '缺少词字段'}

        # Records without a word are skipped; a missing weight counts as 1.
        entries = [
            {'name': item.get(word_field), 'value': item.get(value_field, 1)}
            for item in data
            if item.get(word_field)
        ]
        return {
            'type': 'wordcloud',
            'title': title,
            'data': entries
        }

    logger.warning(f"未知的可视化类型: {visualization_type}")
    return {}
|
||||
|
||||
def _get_final_outputs(self, dependency_graph, result_map):
    """Collect the results of the components nothing else consumes.

    Args:
        dependency_graph: Mapping of component id -> ids it depends on.
        result_map: Results of the executed components, keyed by id.

    Returns:
        dict: Result of every component that is not listed as a dependency
        of another component and that has an entry in ``result_map``.
    """
    # Every id that feeds at least one other component.
    consumed = set()
    for deps in dependency_graph.values():
        consumed.update(deps)

    return {
        node: result_map[node]
        for node in dependency_graph
        if node not in consumed and node in result_map
    }
|
||||
@@ -0,0 +1,896 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from flask import Blueprint, request, jsonify, current_app
|
||||
from utils.db_manager import DatabaseManager
|
||||
from utils.sensitive_filter import filter_dict
|
||||
from utils.cache_manager import CacheManager
|
||||
|
||||
# Blueprint that groups the workflow endpoints under the /api/workflow prefix.
workflow_bp = Blueprint('workflow', __name__, url_prefix='/api/workflow')
logger = logging.getLogger('workflow_api')
logger.setLevel(logging.INFO)

# Cache manager for workflow data.
# NOTE(review): cache_duration=1 — the unit is not visible from here; confirm
# this value is intended.
workflow_cache = CacheManager(name="workflows", memory_capacity=100, cache_duration=1)
|
||||
|
||||
# Built-in crawler configuration templates.
# NOTE(review): config keys here are camelCase (e.g. "crawlDepth"); confirm
# that the code consuming a template's config reads the same spelling.
DEFAULT_CRAWLER_TEMPLATES = [
    {
        "id": "default_weibo",
        "name": "微博热门话题",
        "description": "抓取微博热门话题及相关评论",
        "icon": "fab fa-weibo",
        "config": {
            "source": "weibo",
            "crawlDepth": 2,
            "interval": 3600,
            "maxRetries": 3,
            "timeout": 30,
            "maxConcurrent": 2,
            "userAgent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
            "filters": {
                "minComments": 10,
                "minLikes": 50,
                "excludeKeywords": []
            }
        }
    },
    {
        "id": "weibo_trending",
        "name": "微博热搜榜",
        "description": "抓取微博热搜榜单内容",
        "icon": "fas fa-fire",
        "config": {
            "source": "weibo_trending",
            "crawlDepth": 1,
            "interval": 1800,
            "maxRetries": 3,
            "timeout": 20,
            "maxConcurrent": 1,
            "filters": {
                "topN": 50,
                "excludeKeywords": []
            }
        }
    }
]
|
||||
|
||||
# Built-in analysis workflow templates. Each template lists its components
# (with canvas positions) and the connections between them by component id.
# NOTE(review): the data-source configs use the key "filter", and the
# preprocessing/visualization configs carry no "preprocessing_type" /
# "visualization_type"; confirm these match what the executing code reads.
DEFAULT_ANALYSIS_TEMPLATES = [
    {
        "id": "sentiment_analysis",
        "name": "情感分析流程",
        "description": "对文本进行情感分析",
        "icon": "fas fa-smile",
        "components": [
            {
                "id": "data_source",
                "type": "data_source",
                "name": "数据源",
                "config": {
                    "source_type": "database",
                    "table": "comments",
                    "filter": {
                        "timeRange": "1d"
                    }
                },
                "position": {"x": 100, "y": 100}
            },
            {
                "id": "text_preprocessing",
                "type": "preprocessing",
                "name": "文本预处理",
                "config": {
                    "removeStopwords": True,
                    "removeURLs": True,
                    "removeEmojis": False
                },
                "position": {"x": 300, "y": 100}
            },
            {
                "id": "sentiment_model",
                "type": "model",
                "name": "情感分析模型",
                "config": {
                    "model_type": "sentiment",
                    "api": "openai",
                    "optimize_for": "balanced"
                },
                "position": {"x": 500, "y": 100}
            },
            {
                "id": "visualization",
                "type": "visualization",
                "name": "可视化",
                "config": {
                    "chart_type": "pie",
                    "title": "情感分布"
                },
                "position": {"x": 700, "y": 100}
            }
        ],
        "connections": [
            {"source": "data_source", "target": "text_preprocessing"},
            {"source": "text_preprocessing", "target": "sentiment_model"},
            {"source": "sentiment_model", "target": "visualization"}
        ]
    },
    {
        "id": "topic_analysis",
        "name": "话题分析流程",
        "description": "对文本进行话题分类和关键词提取",
        "icon": "fas fa-tasks",
        "components": [
            {
                "id": "data_source",
                "type": "data_source",
                "name": "数据源",
                "config": {
                    "source_type": "database",
                    "table": "weibo_posts",
                    "filter": {
                        "timeRange": "7d"
                    }
                },
                "position": {"x": 100, "y": 100}
            },
            {
                "id": "text_preprocessing",
                "type": "preprocessing",
                "name": "文本预处理",
                "config": {
                    "removeStopwords": True,
                    "removeURLs": True,
                    "removeEmojis": True
                },
                "position": {"x": 300, "y": 100}
            },
            {
                "id": "topic_model",
                "type": "model",
                "name": "话题分类模型",
                "config": {
                    "model_type": "topic_classification",
                    "api": "deepseek",
                    "optimize_for": "performance"
                },
                "position": {"x": 500, "y": 50}
            },
            {
                "id": "keyword_model",
                "type": "model",
                "name": "关键词提取模型",
                "config": {
                    "model_type": "keyword_extraction",
                    "api": "openai",
                    "optimize_for": "balanced"
                },
                "position": {"x": 500, "y": 150}
            },
            {
                "id": "topic_viz",
                "type": "visualization",
                "name": "话题分布",
                "config": {
                    "chart_type": "bar",
                    "title": "话题分布"
                },
                "position": {"x": 700, "y": 50}
            },
            {
                "id": "keyword_viz",
                "type": "visualization",
                "name": "关键词云",
                "config": {
                    "chart_type": "wordcloud",
                    "title": "热门关键词"
                },
                "position": {"x": 700, "y": 150}
            }
        ],
        "connections": [
            {"source": "data_source", "target": "text_preprocessing"},
            {"source": "text_preprocessing", "target": "topic_model"},
            {"source": "text_preprocessing", "target": "keyword_model"},
            {"source": "topic_model", "target": "topic_viz"},
            {"source": "keyword_model", "target": "keyword_viz"}
        ]
    }
]
|
||||
|
||||
# Catalogue of the components available by default, grouped by component
# category. Each entry describes its configuration options in
# "config_schema" (type, description and an optional default/required flag).
AVAILABLE_COMPONENTS = {
    "data_source": [
        {
            "id": "database",
            "name": "数据库",
            "description": "从系统数据库获取数据",
            "config_schema": {
                "table": {"type": "string", "description": "数据表名", "required": True},
                "filter": {"type": "object", "description": "数据过滤条件"}
            }
        },
        {
            "id": "api",
            "name": "API接口",
            "description": "从外部API获取数据",
            "config_schema": {
                "url": {"type": "string", "description": "API URL", "required": True},
                "method": {"type": "string", "description": "请求方法", "default": "GET"},
                "headers": {"type": "object", "description": "请求头"},
                "params": {"type": "object", "description": "请求参数"}
            }
        },
        {
            "id": "csv",
            "name": "CSV文件",
            "description": "从CSV文件导入数据",
            "config_schema": {
                "file_path": {"type": "string", "description": "文件路径", "required": True},
                "encoding": {"type": "string", "description": "文件编码", "default": "utf-8"},
                "delimiter": {"type": "string", "description": "分隔符", "default": ","}
            }
        }
    ],
    "preprocessing": [
        {
            "id": "text_preprocessing",
            "name": "文本预处理",
            "description": "清洗和规范化文本数据",
            "config_schema": {
                "removeStopwords": {"type": "boolean", "description": "去除停用词", "default": True},
                "removeURLs": {"type": "boolean", "description": "去除URL", "default": True},
                "removeEmojis": {"type": "boolean", "description": "去除表情符号", "default": False},
                "lowercase": {"type": "boolean", "description": "转为小写", "default": True}
            }
        },
        {
            "id": "tokenization",
            "name": "分词",
            "description": "将文本切分为词语或标记",
            "config_schema": {
                "method": {"type": "string", "description": "分词方法", "default": "jieba"},
                "pos_tagging": {"type": "boolean", "description": "进行词性标注", "default": False}
            }
        },
        {
            "id": "feature_extraction",
            "name": "特征提取",
            "description": "从文本提取数值特征",
            "config_schema": {
                "method": {"type": "string", "description": "特征提取方法", "default": "tfidf"},
                "max_features": {"type": "integer", "description": "最大特征数", "default": 1000}
            }
        }
    ],
    "model": [
        {
            "id": "sentiment",
            "name": "情感分析",
            "description": "分析文本情感倾向",
            "config_schema": {
                "api": {"type": "string", "description": "使用的API", "default": "openai"},
                "model_type": {"type": "string", "description": "模型类型", "default": "sentiment_analysis"},
                "optimize_for": {"type": "string", "description": "优化目标", "default": "balanced"}
            }
        },
        {
            "id": "topic_classification",
            "name": "话题分类",
            "description": "对文本进行话题分类",
            "config_schema": {
                "api": {"type": "string", "description": "使用的API", "default": "deepseek"},
                "model_type": {"type": "string", "description": "模型类型", "default": "topic_classification"},
                "optimize_for": {"type": "string", "description": "优化目标", "default": "performance"}
            }
        },
        {
            "id": "keyword_extraction",
            "name": "关键词提取",
            "description": "从文本中提取关键词",
            "config_schema": {
                "api": {"type": "string", "description": "使用的API", "default": "openai"},
                "model_type": {"type": "string", "description": "模型类型", "default": "keyword_extraction"},
                "optimize_for": {"type": "string", "description": "优化目标", "default": "balanced"}
            }
        },
        {
            "id": "custom_ai",
            "name": "自定义AI模型",
            "description": "使用自定义AI模型进行分析",
            "config_schema": {
                "model_path": {"type": "string", "description": "模型路径", "required": True},
                "model_type": {"type": "string", "description": "模型类型", "required": True}
            }
        }
    ],
    "visualization": [
        {
            "id": "line_chart",
            "name": "折线图",
            "description": "展示数据随时间的变化趋势",
            "config_schema": {
                "title": {"type": "string", "description": "图表标题", "default": "时间趋势"},
                "x_axis": {"type": "string", "description": "X轴字段", "default": "time"},
                "y_axis": {"type": "string", "description": "Y轴字段", "default": "value"},
                "color": {"type": "string", "description": "线条颜色", "default": "#1890ff"}
            }
        },
        {
            "id": "bar_chart",
            "name": "柱状图",
            "description": "展示不同类别的数据对比",
            "config_schema": {
                "title": {"type": "string", "description": "图表标题", "default": "数据对比"},
                "x_axis": {"type": "string", "description": "X轴字段", "default": "category"},
                "y_axis": {"type": "string", "description": "Y轴字段", "default": "value"}
            }
        },
        {
            "id": "pie_chart",
            "name": "饼图",
            "description": "展示数据的构成比例",
            "config_schema": {
                "title": {"type": "string", "description": "图表标题", "default": "比例分布"},
                "value_field": {"type": "string", "description": "值字段", "default": "value"},
                "label_field": {"type": "string", "description": "标签字段", "default": "label"}
            }
        },
        {
            "id": "wordcloud",
            "name": "词云图",
            "description": "直观展示文本中的高频词",
            "config_schema": {
                "title": {"type": "string", "description": "图表标题", "default": "关键词云"},
                "max_words": {"type": "integer", "description": "最大词数", "default": 100},
                "color_scheme": {"type": "string", "description": "配色方案", "default": "viridis"}
            }
        },
        {
            "id": "heatmap",
            "name": "热力图",
            "description": "展示数据的密度分布",
            "config_schema": {
                "title": {"type": "string", "description": "图表标题", "default": "热力分布"},
                "x_axis": {"type": "string", "description": "X轴字段", "default": "x"},
                "y_axis": {"type": "string", "description": "Y轴字段", "default": "y"},
                "value_field": {"type": "string", "description": "值字段", "default": "value"}
            }
        }
    ]
}
|
||||
|
||||
@workflow_bp.route('/crawler-templates', methods=['GET'])
def get_crawler_templates():
    """Return the list of crawler configuration templates.

    The list is read from ``workflow_cache`` under the key
    ``'crawler_templates'``. On a cache miss it is rebuilt as
    ``DEFAULT_CRAWLER_TEMPLATES`` followed by the non-deleted rows of the
    ``crawler_templates`` table (newest first) and stored back in the cache.

    Returns:
        The ``jsonify`` response ``{'success': True, 'data': ...}`` where
        ``data`` is the template list passed through ``filter_dict``.
    """
    templates = workflow_cache.get('crawler_templates')
    if templates is None:
        # Cache miss: load the user-defined templates from the database.
        db = DatabaseManager.get_connection()
        cursor = db.cursor()
        try:
            cursor.execute("""
                SELECT id, name, description, icon, config
                FROM crawler_templates
                WHERE deleted = 0
                ORDER BY created_at DESC
            """)
            user_templates = cursor.fetchall()
        finally:
            # Release the cursor even when the query raises.
            cursor.close()

        # Built-in templates first, then the user-defined ones.
        templates = DEFAULT_CRAWLER_TEMPLATES + list(user_templates)

        # Cache the merged list for subsequent requests.
        workflow_cache.set('crawler_templates', templates)

    return jsonify({
        'success': True,
        'data': filter_dict(templates)
    })
|
||||
|
||||
@workflow_bp.route('/crawler-templates/<template_id>', methods=['GET'])
def get_crawler_template(template_id):
    """Return a single crawler configuration template.

    Args:
        template_id: Identifier taken from the URL path.

    Returns:
        ``{'success': True, 'data': ...}`` when the id matches a built-in
        template or a non-deleted database row; otherwise a
        ``{'success': False, 'message': ...}`` payload with HTTP status 404.
    """
    # Built-in templates are checked first and never hit the database.
    for template in DEFAULT_CRAWLER_TEMPLATES:
        if template['id'] == template_id:
            return jsonify({
                'success': True,
                'data': filter_dict(template)
            })

    # Fall back to the user-defined templates stored in the database.
    db = DatabaseManager.get_connection()
    cursor = db.cursor()
    try:
        cursor.execute("""
            SELECT id, name, description, icon, config
            FROM crawler_templates
            WHERE id = %s AND deleted = 0
        """, (template_id,))
        template = cursor.fetchone()
    finally:
        # Release the cursor even when the query raises.
        cursor.close()

    if not template:
        return jsonify({
            'success': False,
            'message': f"未找到模板: {template_id}"
        }), 404

    return jsonify({
        'success': True,
        'data': filter_dict(template)
    })
|
||||
|
||||
@workflow_bp.route('/crawler-templates', methods=['POST'])
def create_crawler_template():
    """Create a user-defined crawler configuration template.

    The JSON request body must be an object containing ``name``,
    ``description`` and ``config``; ``icon`` is optional and defaults to
    ``'fas fa-spider'``.

    Returns:
        * 201 with ``{'success': True, 'data': ...}`` on success.
        * 400 when the body is not a JSON object or a required field is
          missing.
        * 500 when the insert fails (the transaction is rolled back).
    """
    # silent=True yields None instead of raising on a missing/invalid body,
    # so malformed requests get a JSON 400 rather than an unhandled error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({
            'success': False,
            'message': "请求体必须是有效的JSON对象"
        }), 400

    # Validate required fields.
    for field in ('name', 'description', 'config'):
        if field not in data:
            return jsonify({
                'success': False,
                'message': f"缺少必要字段: {field}"
            }), 400

    # Generate the template id.
    template_id = f"template_{int(time.time())}_{uuid.uuid4().hex[:8]}"

    # Take one timestamp so created_at and updated_at are guaranteed equal.
    now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    template = {
        'id': template_id,
        'name': data['name'],
        'description': data['description'],
        'icon': data.get('icon', 'fas fa-spider'),
        'config': data['config'],
        'created_at': now,
        'updated_at': now,
        'deleted': 0
    }

    # Persist to the database.
    db = DatabaseManager.get_connection()
    cursor = db.cursor()
    try:
        cursor.execute("""
            INSERT INTO crawler_templates
            (id, name, description, icon, config, created_at, updated_at, deleted)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
        """, (
            template['id'],
            template['name'],
            template['description'],
            template['icon'],
            # The config is stored serialized as JSON text.
            json.dumps(template['config']),
            template['created_at'],
            template['updated_at'],
            template['deleted']
        ))
        db.commit()

        # Drop the cached list so the next read includes the new template.
        workflow_cache.invalidate('crawler_templates')

        return jsonify({
            'success': True,
            'data': filter_dict(template)
        }), 201
    except Exception as e:
        db.rollback()
        logger.error(f"创建爬虫模板失败: {e}")
        return jsonify({
            'success': False,
            'message': f"创建模板失败: {str(e)}"
        }), 500
    finally:
        cursor.close()
|
||||
|
||||
@workflow_bp.route('/crawler-templates/<template_id>', methods=['PUT'])
def update_crawler_template(template_id):
    """Update a user-defined crawler configuration template.

    Only the fields present in the JSON body (``name``, ``description``,
    ``icon``, ``config``) are changed; ``updated_at`` is always refreshed.

    Args:
        template_id: Identifier taken from the URL path.

    Returns:
        * 200 with a success message when the row was updated.
        * 400 when the body is not a JSON object.
        * 404 when no non-deleted row has this id.
        * 500 when a database operation fails (the transaction is rolled
          back).
    """
    # silent=True yields None instead of raising on a missing/invalid body.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({
            'success': False,
            'message': "请求体必须是有效的JSON对象"
        }), 400

    db = DatabaseManager.get_connection()
    cursor = db.cursor()
    try:
        # Make sure the template exists and is not soft-deleted.
        cursor.execute("""
            SELECT id FROM crawler_templates
            WHERE id = %s AND deleted = 0
        """, (template_id,))
        if not cursor.fetchone():
            return jsonify({
                'success': False,
                'message': f"未找到模板: {template_id}"
            }), 404

        # Collect only the fields the caller supplied.
        update_data = {
            key: data[key]
            for key in ('name', 'description', 'icon')
            if key in data
        }
        if 'config' in data:
            update_data['config'] = json.dumps(data['config'])
        update_data['updated_at'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        # Column names come from the fixed whitelist above, so interpolating
        # them is safe; all values are bound as query parameters.
        assignments = ", ".join(f"{column} = %s" for column in update_data)
        sql = "UPDATE crawler_templates SET " + assignments + " WHERE id = %s"

        cursor.execute(sql, list(update_data.values()) + [template_id])
        db.commit()

        # Drop the cached list so the next read reflects the change.
        workflow_cache.invalidate('crawler_templates')

        return jsonify({
            'success': True,
            'message': "模板更新成功"
        })
    except Exception as e:
        db.rollback()
        logger.error(f"更新爬虫模板失败: {e}")
        return jsonify({
            'success': False,
            'message': f"更新模板失败: {str(e)}"
        }), 500
    finally:
        # Runs on every exit path, including the 404 return above.
        cursor.close()
|
||||
|
||||
@workflow_bp.route('/crawler-templates/<template_id>', methods=['DELETE'])
def delete_crawler_template(template_id):
    """Soft-delete a user-defined crawler configuration template.

    The row is kept; its ``deleted`` flag is set to 1 and ``updated_at`` is
    refreshed.

    Args:
        template_id: Identifier taken from the URL path.

    Returns:
        * 200 with a success message when the row was flagged as deleted.
        * 404 when no non-deleted row has this id.
        * 500 when a database operation fails (the transaction is rolled
          back).
    """
    db = DatabaseManager.get_connection()
    cursor = db.cursor()
    try:
        # Make sure the template exists and is not already soft-deleted.
        cursor.execute("""
            SELECT id FROM crawler_templates
            WHERE id = %s AND deleted = 0
        """, (template_id,))
        if not cursor.fetchone():
            return jsonify({
                'success': False,
                'message': f"未找到模板: {template_id}"
            }), 404

        # Soft delete: flag the row instead of removing it.
        cursor.execute("""
            UPDATE crawler_templates
            SET deleted = 1, updated_at = %s
            WHERE id = %s
        """, (datetime.now().strftime('%Y-%m-%d %H:%M:%S'), template_id))
        db.commit()

        # Drop the cached list so the deleted template disappears from it.
        workflow_cache.invalidate('crawler_templates')

        return jsonify({
            'success': True,
            'message': "模板删除成功"
        })
    except Exception as e:
        db.rollback()
        logger.error(f"删除爬虫模板失败: {e}")
        return jsonify({
            'success': False,
            'message': f"删除模板失败: {str(e)}"
        }), 500
    finally:
        # Runs on every exit path, including the 404 return above.
        cursor.close()
|
||||
|
||||
@workflow_bp.route('/analysis-templates', methods=['GET'])
def get_analysis_templates():
    """Return the list of analysis workflow templates.

    The list is read from ``workflow_cache`` under the key
    ``'analysis_templates'``. On a cache miss it is rebuilt as
    ``DEFAULT_ANALYSIS_TEMPLATES`` followed by the non-deleted rows of the
    ``analysis_templates`` table (newest first) and stored back in the cache.

    Returns:
        The ``jsonify`` response ``{'success': True, 'data': ...}`` where
        ``data`` is the template list passed through ``filter_dict``.
    """
    templates = workflow_cache.get('analysis_templates')
    if templates is None:
        # Cache miss: load the user-defined templates from the database.
        db = DatabaseManager.get_connection()
        cursor = db.cursor()
        try:
            cursor.execute("""
                SELECT id, name, description, icon, components, connections
                FROM analysis_templates
                WHERE deleted = 0
                ORDER BY created_at DESC
            """)
            user_templates = cursor.fetchall()
        finally:
            # Release the cursor even when the query raises.
            cursor.close()

        # Built-in templates first, then the user-defined ones.
        templates = DEFAULT_ANALYSIS_TEMPLATES + list(user_templates)

        # Cache the merged list for subsequent requests.
        workflow_cache.set('analysis_templates', templates)

    return jsonify({
        'success': True,
        'data': filter_dict(templates)
    })
|
||||
|
||||
@workflow_bp.route('/analysis-templates/<template_id>', methods=['GET'])
def get_analysis_template(template_id):
    """Return a single analysis workflow template.

    Args:
        template_id: Identifier taken from the URL path.

    Returns:
        ``{'success': True, 'data': ...}`` when the id matches a built-in
        template or a non-deleted database row; otherwise a
        ``{'success': False, 'message': ...}`` payload with HTTP status 404.
    """
    # Built-in templates are checked first and never hit the database.
    for template in DEFAULT_ANALYSIS_TEMPLATES:
        if template['id'] == template_id:
            return jsonify({
                'success': True,
                'data': filter_dict(template)
            })

    # Fall back to the user-defined templates stored in the database.
    db = DatabaseManager.get_connection()
    cursor = db.cursor()
    try:
        cursor.execute("""
            SELECT id, name, description, icon, components, connections
            FROM analysis_templates
            WHERE id = %s AND deleted = 0
        """, (template_id,))
        template = cursor.fetchone()
    finally:
        # Release the cursor even when the query raises.
        cursor.close()

    if not template:
        return jsonify({
            'success': False,
            'message': f"未找到模板: {template_id}"
        }), 404

    return jsonify({
        'success': True,
        'data': filter_dict(template)
    })
|
||||
|
||||
@workflow_bp.route('/analysis-templates', methods=['POST'])
def create_analysis_template():
    """Create a user-defined analysis workflow template.

    The JSON request body must be an object containing ``name``,
    ``description``, ``components`` and ``connections``; ``icon`` is optional
    and defaults to ``'fas fa-chart-line'``.

    Returns:
        * 201 with ``{'success': True, 'data': ...}`` on success.
        * 400 when the body is not a JSON object or a required field is
          missing.
        * 500 when the insert fails (the transaction is rolled back).
    """
    # silent=True yields None instead of raising on a missing/invalid body,
    # so malformed requests get a JSON 400 rather than an unhandled error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({
            'success': False,
            'message': "请求体必须是有效的JSON对象"
        }), 400

    # Validate required fields.
    for field in ('name', 'description', 'components', 'connections'):
        if field not in data:
            return jsonify({
                'success': False,
                'message': f"缺少必要字段: {field}"
            }), 400

    # Generate the template id.
    template_id = f"template_{int(time.time())}_{uuid.uuid4().hex[:8]}"

    # Take one timestamp so created_at and updated_at are guaranteed equal.
    now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    template = {
        'id': template_id,
        'name': data['name'],
        'description': data['description'],
        'icon': data.get('icon', 'fas fa-chart-line'),
        'components': data['components'],
        'connections': data['connections'],
        'created_at': now,
        'updated_at': now,
        'deleted': 0
    }

    # Persist to the database.
    db = DatabaseManager.get_connection()
    cursor = db.cursor()
    try:
        cursor.execute("""
            INSERT INTO analysis_templates
            (id, name, description, icon, components, connections, created_at, updated_at, deleted)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
        """, (
            template['id'],
            template['name'],
            template['description'],
            template['icon'],
            # Components and connections are stored serialized as JSON text.
            json.dumps(template['components']),
            json.dumps(template['connections']),
            template['created_at'],
            template['updated_at'],
            template['deleted']
        ))
        db.commit()

        # Drop the cached list so the next read includes the new template.
        workflow_cache.invalidate('analysis_templates')

        return jsonify({
            'success': True,
            'data': filter_dict(template)
        }), 201
    except Exception as e:
        db.rollback()
        logger.error(f"创建分析模板失败: {e}")
        return jsonify({
            'success': False,
            'message': f"创建模板失败: {str(e)}"
        }), 500
    finally:
        cursor.close()
|
||||
|
||||
@workflow_bp.route('/analysis-templates/<template_id>', methods=['PUT'])
def update_analysis_template(template_id):
    """Update a user-defined analysis workflow template.

    Only the fields present in the JSON body (``name``, ``description``,
    ``icon``, ``components``, ``connections``) are changed; ``updated_at`` is
    always refreshed.

    Args:
        template_id: Identifier taken from the URL path.

    Returns:
        * 200 with a success message when the row was updated.
        * 400 when the body is not a JSON object.
        * 404 when no non-deleted row has this id.
        * 500 when a database operation fails (the transaction is rolled
          back).
    """
    # silent=True yields None instead of raising on a missing/invalid body.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({
            'success': False,
            'message': "请求体必须是有效的JSON对象"
        }), 400

    db = DatabaseManager.get_connection()
    cursor = db.cursor()
    try:
        # Make sure the template exists and is not soft-deleted.
        cursor.execute("""
            SELECT id FROM analysis_templates
            WHERE id = %s AND deleted = 0
        """, (template_id,))
        if not cursor.fetchone():
            return jsonify({
                'success': False,
                'message': f"未找到模板: {template_id}"
            }), 404

        # Collect only the fields the caller supplied; the two structured
        # fields are serialized to JSON text before being stored.
        update_data = {
            key: data[key]
            for key in ('name', 'description', 'icon')
            if key in data
        }
        for key in ('components', 'connections'):
            if key in data:
                update_data[key] = json.dumps(data[key])
        update_data['updated_at'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        # Column names come from the fixed whitelist above, so interpolating
        # them is safe; all values are bound as query parameters.
        assignments = ", ".join(f"{column} = %s" for column in update_data)
        sql = "UPDATE analysis_templates SET " + assignments + " WHERE id = %s"

        cursor.execute(sql, list(update_data.values()) + [template_id])
        db.commit()

        # Drop the cached list so the next read reflects the change.
        workflow_cache.invalidate('analysis_templates')

        return jsonify({
            'success': True,
            'message': "模板更新成功"
        })
    except Exception as e:
        db.rollback()
        logger.error(f"更新分析模板失败: {e}")
        return jsonify({
            'success': False,
            'message': f"更新模板失败: {str(e)}"
        }), 500
    finally:
        # Runs on every exit path, including the 404 return above.
        cursor.close()
|
||||
|
||||
@workflow_bp.route('/analysis-templates/<template_id>', methods=['DELETE'])
def delete_analysis_template(template_id):
    """Soft-delete a user-defined analysis workflow template.

    The row is kept; its ``deleted`` flag is set to 1 and ``updated_at`` is
    refreshed.

    Args:
        template_id: Identifier taken from the URL path.

    Returns:
        * 200 with a success message when the row was flagged as deleted.
        * 404 when no non-deleted row has this id.
        * 500 when a database operation fails (the transaction is rolled
          back).
    """
    db = DatabaseManager.get_connection()
    cursor = db.cursor()
    try:
        # Make sure the template exists and is not already soft-deleted.
        cursor.execute("""
            SELECT id FROM analysis_templates
            WHERE id = %s AND deleted = 0
        """, (template_id,))
        if not cursor.fetchone():
            return jsonify({
                'success': False,
                'message': f"未找到模板: {template_id}"
            }), 404

        # Soft delete: flag the row instead of removing it.
        cursor.execute("""
            UPDATE analysis_templates
            SET deleted = 1, updated_at = %s
            WHERE id = %s
        """, (datetime.now().strftime('%Y-%m-%d %H:%M:%S'), template_id))
        db.commit()

        # Drop the cached list so the deleted template disappears from it.
        workflow_cache.invalidate('analysis_templates')

        return jsonify({
            'success': True,
            'message': "模板删除成功"
        })
    except Exception as e:
        db.rollback()
        logger.error(f"删除分析模板失败: {e}")
        return jsonify({
            'success': False,
            'message': f"删除模板失败: {str(e)}"
        }), 500
    finally:
        # Runs on every exit path, including the 404 return above.
        cursor.close()
|
||||
|
||||
@workflow_bp.route('/components', methods=['GET'])
def get_available_components():
    """Return the catalogue of available workflow components.

    The module-level ``AVAILABLE_COMPONENTS`` structure is passed through
    ``filter_dict`` and wrapped in the usual success envelope.
    """
    component_catalogue = filter_dict(AVAILABLE_COMPONENTS)
    payload = {'success': True, 'data': component_catalogue}
    return jsonify(payload)
|
||||
|
||||
@workflow_bp.route('/run-workflow', methods=['POST'])
def run_workflow():
    """Accept a workflow execution request and return a task id.

    This is a placeholder: the request is validated and logged and a fresh
    task id is generated, but no execution is started here. A real
    implementation would build an execution plan from the component types and
    their connections and then run it.

    Returns:
        * 200 with ``{'task_id': ..., 'status': 'pending'}`` under ``data``.
        * 400 when the body is not a JSON object or lacks ``components`` /
          ``connections``.
    """
    # silent=True yields None instead of raising on a missing/invalid body;
    # treating any non-object body as empty routes it to the 400 below
    # rather than crashing on the membership test.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    # Validate required fields.
    if 'components' not in data or 'connections' not in data:
        return jsonify({
            'success': False,
            'message': "缺少必要字段: components 或 connections"
        }), 400

    # Record the execution request.
    logger.info(f"收到工作流执行请求,组件数量: {len(data['components'])}, 连接数量: {len(data['connections'])}")

    # Create the task id.
    task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"

    # Hand the task id back to the caller.
    return jsonify({
        'success': True,
        'message': "工作流执行请求已提交",
        'data': {
            'task_id': task_id,
            'status': 'pending'
        }
    })
|
||||
|
||||
@workflow_bp.route('/task-status/<task_id>', methods=['GET'])
def get_task_status(task_id):
    """Return the execution status of a workflow task.

    Placeholder: no task state is looked up. A fixed sample status is
    returned for whatever ``task_id`` is given, reported as running at 45%
    and started two minutes before the current time.
    """
    timestamp_format = '%Y-%m-%d %H:%M:%S'

    # Sample values only; a real implementation would query the task state.
    started_at = (datetime.now() - timedelta(minutes=2)).strftime(timestamp_format)
    updated_at = datetime.now().strftime(timestamp_format)

    sample_status = {
        'task_id': task_id,
        'status': 'running',
        'progress': 45,
        'message': "正在执行数据预处理",
        'started_at': started_at,
        'updated_at': updated_at,
    }

    response_body = {'success': True, 'data': sample_status}
    return jsonify(response_body)
|
||||
Reference in New Issue
Block a user