Comprehensive security enhancement, fix race conditions and injection vulnerabilities.

This commit is contained in:
戒酒的李白
2025-03-08 00:17:42 +08:00
parent 5630b30002
commit f81a71e970
3 changed files with 451 additions and 344 deletions
+148 -195
View File
@@ -3,13 +3,20 @@ import re
import getpass import getpass
import pymysql import pymysql
import subprocess import subprocess
from flask import Flask, session, request, redirect, render_template from flask import Flask, session, request, redirect, render_template, jsonify
from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.schedulers.background import BackgroundScheduler
from pytz import utc from pytz import utc
from datetime import datetime, timedelta from datetime import datetime, timedelta
import time import time
from utils.logger import app_logger as logging from utils.logger import app_logger as logging
from utils.db_manager import DatabaseManager from utils.db_manager import DatabaseManager
import secrets
from dotenv import load_dotenv
from functools import wraps
import bleach
# 加载环境变量
load_dotenv()
def get_db_connection_interactive(): def get_db_connection_interactive():
""" """
@@ -18,17 +25,17 @@ def get_db_connection_interactive():
""" """
print("请依次输入数据库连接信息(直接按回车使用默认值):") print("请依次输入数据库连接信息(直接按回车使用默认值):")
host = input(" 1. 主机 (默认: localhost): ") or "localhost" host = input(" 1. 主机 (默认: localhost): ") or os.getenv('DB_HOST', 'localhost')
port_str = input(" 2. 端口 (默认: 3306): ") or "3306" port_str = input(" 2. 端口 (默认: 3306): ") or os.getenv('DB_PORT', '3306')
try: try:
port = int(port_str) port = int(port_str)
except ValueError: except ValueError:
logging.warning("端口号无效,使用默认端口 3306。") logging.warning("端口号无效,使用默认端口 3306。")
port = 3306 port = 3306
user = input(" 3. 用户名 (默认: root): ") or "root" user = input(" 3. 用户名 (默认: root): ") or os.getenv('DB_USER', 'root')
password = getpass.getpass(" 4. 密码 (默认: 12345678): ") or "12345678" password = getpass.getpass(" 4. 密码: ") or os.getenv('DB_PASSWORD', '')
db_name = input(" 5. 数据库名 (默认: Weibo_PublicOpinion_AnalysisSystem): ") or "Weibo_PublicOpinion_AnalysisSystem" db_name = input(" 5. 数据库名 (默认: Weibo_PublicOpinion_AnalysisSystem): ") or os.getenv('DB_NAME', 'Weibo_PublicOpinion_AnalysisSystem')
logging.info(f"尝试连接到数据库: {user}@{host}:{port}/{db_name}") logging.info(f"尝试连接到数据库: {user}@{host}:{port}/{db_name}")
@@ -40,237 +47,183 @@ def get_db_connection_interactive():
password=password, password=password,
database=db_name, database=db_name,
charset='utf8mb4', charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor # 返回字典格式 cursorclass=pymysql.cursors.DictCursor,
ssl={'ssl': {'ca': os.getenv('DB_SSL_CA')}} if os.getenv('DB_SSL_CA') else None
) )
logging.info("数据库连接成功。") logging.info("数据库连接成功。")
return connection return connection
except pymysql.MySQLError as e: except pymysql.MySQLError as e:
logging.error(f"数据库连接失败: {e}") logging.error(f"数据库连接失败: {e}")
exit(1) raise
def initialize_database(connection, sql_file_path): def sanitize_input(text):
""" """清理用户输入,防止XSS攻击"""
执行 SQL 文件中的语句以初始化数据库。 if text is None:
return None
:param connection: 已建立的数据库连接 return bleach.clean(str(text), strip=True)
:param sql_file_path: SQL 文件的路径
"""
try:
with open(sql_file_path, 'r', encoding='utf8') as file:
sql_commands = file.read()
with connection.cursor() as cursor:
for statement in sql_commands.split(';'):
statement = statement.strip()
if statement:
cursor.execute(statement)
connection.commit()
logging.info("数据库初始化成功。")
except FileNotFoundError:
logging.error(f"SQL 文件未找到: {sql_file_path}")
exit(1)
except pymysql.MySQLError as e:
logging.error(f"执行 SQL 时出错: {e}")
connection.rollback()
exit(1)
except Exception as e:
logging.error(f"初始化数据库时出错: {e}")
connection.rollback()
exit(1)
def prompt_first_run(): def set_secure_headers(response):
""" """设置安全响应头"""
询问用户是否首次运行,需要初始化数据库。 response.headers['X-Content-Type-Options'] = 'nosniff'
response.headers['X-Frame-Options'] = 'SAMEORIGIN'
:return: BooleanTrue 表示需要初始化数据库 response.headers['X-XSS-Protection'] = '1; mode=block'
""" response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains'
while True: response.headers['Content-Security-Policy'] = "default-src 'self'"
choice = input("是否首次运行该项目,需要初始化数据库?(Y/n): ").strip().lower() return response
if choice in ['y', 'yes', '']:
return True
elif choice in ['n', 'no']:
return False
else:
print("请输入 Y 或 N。")
# 初始化 Flask 应用 # 初始化 Flask 应用
app = Flask(__name__) app = Flask(__name__)
app.secret_key = 'this is secret_key you know ?' # 设置 Flask 的密钥,用于 session 加密 app.secret_key = os.getenv('FLASK_SECRET_KEY', secrets.token_hex(32))
app.config['SESSION_COOKIE_SECURE'] = True
app.config['SESSION_COOKIE_HTTPONLY'] = True
app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'
app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=2)
# 导入蓝图 # 导入蓝图
from views.page import page from views.page import page
from views.user import user from views.user import user
from views.spider_control import spider_bp from views.spider_control import spider_bp
app.register_blueprint(page.pb) # 注册页面蓝图 app.register_blueprint(page.pb)
app.register_blueprint(user.ub) # 注册用户蓝图 app.register_blueprint(user.ub)
app.register_blueprint(spider_bp) # 注册爬虫控制蓝图 app.register_blueprint(spider_bp)
# 首页路由,清空 session # 首页路由
@app.route('/') @app.route('/')
def hello_world(): def hello_world():
session.clear() # 清空 session,用户退出登录 session.clear()
return "Session Cleared" return redirect('/user/login')
# 中间件:处理请求前的逻辑 # 请求前中间件
@app.before_request @app.before_request
def before_request(): def before_request():
# 检查是否是HTTPS
if not request.is_secure and not app.debug:
url = request.url.replace('http://', 'https://', 1)
return redirect(url, code=301)
# 如果请求的是静态文件路径,允许访问 # 如果请求的是静态文件路径,允许访问
if request.path.startswith('/static'): if request.path.startswith('/static'):
return return
# 如果请求的是登录或注册页面,不需要会话验证 # 如果请求的是登录或注册页面,不需要会话验证
if request.path in ['/user/login', '/user/register']: if request.path in ['/user/login', '/user/register']:
return return
# 如果 session 中没有用户名,重定向到登录页面 # 验证会话
if not session.get('username'): if not session.get('username'):
return redirect('/user/login') return redirect('/user/login')
# 404 错误页面路由
@app.route('/<path:path>')
def catch_all(path):
return render_template('404.html') # 如果路径不存在,返回 404 页面
# 定义定时任务,运行爬虫脚本
def run_script():
current_dir = os.path.dirname(os.path.abspath(__file__)) # 获取当前脚本的目录
spider_script = os.path.join(current_dir, 'spider', 'main.py') # 爬虫脚本路径
# cutComments_script = os.path.join(current_dir, 'utils', 'cutComments.py') # 评论处理脚本路径
# cipingTotal_script = os.path.join(current_dir, 'utils', 'cipingTotal.py') # 评分处理脚本路径
# 定义所有要运行的脚本
scripts = [
("Spider Script", spider_script),
# ("Cut Comments Script", cutComments_script),
# ("Ciping Total Script", cipingTotal_script)
]
# 执行所有脚本
for script_name, script_path in scripts:
try:
logging.info(f"Running {script_name}...")
subprocess.run(['python', script_path], check=True) # 使用 subprocess 执行脚本
logging.info(f"{script_name} finished successfully.")
except subprocess.CalledProcessError as e:
logging.error(f"An error occurred while running {script_name}: {e}")
# 新增功能:动态调度爬虫脚本
def check_database_empty():
"""
检查数据库中的指定表是否为空。
:return: 如果表为空则返回 True,否则返回 False # 验证会话完整性
""" if 'client_info' not in session:
try: session.clear()
connection = pymysql.connect(**DB_CONFIG) return redirect('/user/login')
with connection.cursor() as cursor:
cursor.execute("SELECT COUNT(*) as count FROM article")
result = cursor.fetchone()
count = result['count'] if result and 'count' in result else 0
logging.info(f"数据库中共有 {count} 条记录。")
return count == 0
except pymysql.MySQLError as e:
logging.error(f"检查数据库失败: {e}")
return True # 连接失败时假设数据库为空,以防止阻塞
finally:
if 'connection' in locals():
connection.close()
def dynamic_crawl():
"""
执行爬取任务并根据爬取耗时和获取的数据量动态调度下次爬取时间。
"""
try:
start_time = time.time()
logging.info("开始爬取数据。")
run_script() # 执行爬虫脚本 # 验证客户端信息
current_client = {
end_time = time.time() 'ip': request.remote_addr,
duration = end_time - start_time # 爬取耗时 'user_agent': str(request.user_agent)
}
# 获取爬取后数据库中记录的数量作为数据量 stored_client = session.get('client_info', {})
try:
connection = pymysql.connect(**DB_CONFIG)
with connection.cursor() as cursor:
cursor.execute("SELECT COUNT(*) as count FROM article")
result = cursor.fetchone()
data_fetched = result['count'] if result and 'count' in result else 0
logging.info(f"爬取完成,耗时 {duration:.2f} 秒,数据库中共有 {data_fetched} 条记录。")
except pymysql.MySQLError as e:
logging.error(f"获取数据量失败: {e}")
data_fetched = 0
finally:
if 'connection' in locals():
connection.close()
# 根据爬取耗时和数据量调整下次爬取时间
base_interval = 5 * 60 * 60 # 5小时的基础时间间隔(秒)
if duration > 3600: # 爬取耗时超过1小时
next_interval = base_interval + duration
logging.info(f"检测到长时间爬取。下次爬取将在 {next_interval/3600:.2f} 小时后执行。")
elif data_fetched < 50: # 获取的数据量少于50条
next_interval = base_interval / 2
logging.info(f"获取数据量较少。下次爬取将在 {next_interval/60:.2f} 分钟后执行。")
else:
next_interval = base_interval
logging.info(f"标准爬取完成。下次爬取将在 {next_interval/3600:.2f} 小时后执行。")
# 安排下次爬取任务
scheduler.add_job(dynamic_crawl, 'date', run_date=datetime.now() + timedelta(seconds=next_interval), id='dynamic_crawl')
except Exception as e: if (current_client['ip'] != stored_client.get('ip') or
logging.error(f"动态爬取过程中发生错误: {e}") current_client['user_agent'] != stored_client.get('user_agent')):
session.clear()
return redirect('/user/login')
# 数据库配置,用于动态调度功能 # 响应后中间件
@app.after_request
def after_request(response):
return set_secure_headers(response)
# 错误处理
@app.errorhandler(404)
def not_found_error(error):
return render_template('404.html'), 404
@app.errorhandler(500)
def internal_error(error):
return render_template('500.html'), 500
@app.errorhandler(403)
def forbidden_error(error):
return render_template('403.html'), 403
@app.errorhandler(400)
def bad_request_error(error):
return render_template('400.html'), 400
# 数据库配置
DB_CONFIG = { DB_CONFIG = {
'host': 'localhost', 'host': os.getenv('DB_HOST', 'localhost'),
'user': 'root', 'user': os.getenv('DB_USER', 'root'),
'password': '12345678', 'password': os.getenv('DB_PASSWORD', ''),
'database': 'Weibo_PublicOpinion_AnalysisSystem', 'database': os.getenv('DB_NAME', 'Weibo_PublicOpinion_AnalysisSystem'),
'port': 3306, 'port': int(os.getenv('DB_PORT', '3306')),
'charset': 'utf8mb4' 'charset': 'utf8mb4',
'ssl': {'ca': os.getenv('DB_SSL_CA')} if os.getenv('DB_SSL_CA') else None
} }
# 初始化数据库管理器 # 初始化数据库管理器
DatabaseManager.initialize(DB_CONFIG) DatabaseManager.initialize(DB_CONFIG)
# 主程序入口
if __name__ == '__main__': if __name__ == '__main__':
# 检测是否需要初始化数据库 # 检测是否需要初始化数据库
if prompt_first_run():
# 获取数据库连接
connection = get_db_connection_interactive()
# 执行数据库初始化
sql_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'createTables.sql')
initialize_database(connection, sql_file)
# 关闭数据库连接
connection.close()
logging.info("数据库连接已关闭。")
# 设置定时任务,动态执行爬虫脚本
scheduler = BackgroundScheduler(timezone=utc) # 创建后台任务调度器
scheduler.start() # 启动调度器
# 初始化调度:如果数据库为空,立即爬取;否则,按照基础时间间隔安排首次爬取
if check_database_empty():
logging.info("数据库为空。立即开始初始爬取。")
dynamic_crawl()
else:
logging.info("数据库已有数据。安排首次爬取。")
base_interval = 5 * 60 * 60 # 5小时
scheduler.add_job(dynamic_crawl, 'date', run_date=datetime.now() + timedelta(seconds=base_interval), id='dynamic_crawl')
try: try:
app.run() # 启动 Flask 应用 if os.getenv('INITIALIZE_DB', 'false').lower() == 'true':
finally: connection = get_db_connection_interactive()
scheduler.shutdown() # 确保在应用关闭时关闭调度器 sql_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'createTables.sql')
initialize_database(connection, sql_file)
connection.close()
logging.info("数据库初始化完成。")
except Exception as e:
logging.error(f"数据库初始化失败: {e}")
exit(1)
# 设置日志记录,捕获应用的请求信息 # 设置定时任务
try:
scheduler = BackgroundScheduler(timezone=utc)
scheduler.start()
if check_database_empty():
logging.info("数据库为空。立即开始初始爬取。")
dynamic_crawl()
else:
logging.info("数据库已有数据。安排首次爬取。")
base_interval = int(os.getenv('CRAWL_INTERVAL', '18000')) # 默认5小时
scheduler.add_job(
dynamic_crawl,
'date',
run_date=datetime.now() + timedelta(seconds=base_interval),
id='dynamic_crawl'
)
# 启动应用
app.run(
host=os.getenv('FLASK_HOST', '127.0.0.1'),
port=int(os.getenv('FLASK_PORT', '5000')),
ssl_context='adhoc' if os.getenv('ENABLE_HTTPS', 'false').lower() == 'true' else None
)
except Exception as e:
logging.error(f"应用启动失败: {e}")
if 'scheduler' in locals():
scheduler.shutdown()
exit(1)
finally:
if 'scheduler' in locals():
scheduler.shutdown()
# 请求日志记录
@app.before_request @app.before_request
def log_request_info(): def log_request_info():
# 记录每次请求信息,便于调试和监控 # 记录请求信息,但排除敏感数据
logging.info(f"Request: {request.method} {request.path}") # 记录请求的方式(GET/POST)和路径 sanitized_headers = dict(request.headers)
if 'Authorization' in sanitized_headers:
sanitized_headers['Authorization'] = '[FILTERED]'
if 'Cookie' in sanitized_headers:
sanitized_headers['Cookie'] = '[FILTERED]'
logging.info(
f"Request: {request.method} {request.path}\n"
f"Remote IP: {request.remote_addr}\n"
f"Headers: {sanitized_headers}"
)
+40 -3
View File
@@ -1,3 +1,40 @@
from flask import render_template from flask import render_template, jsonify
def errorResponse(errorMsg): import bleach
return render_template('error.html',errorMsg=errorMsg) import re
def sanitize_error_message(message):
"""
清理和验证错误消息
"""
if not message:
return "发生未知错误"
# 移除任何敏感信息
message = re.sub(r'(password|token|key|secret)=[\w\-]+', r'\1=[FILTERED]', str(message))
# 清理HTML和特殊字符
message = bleach.clean(message, strip=True)
# 限制消息长度
return message[:200] if len(message) > 200 else message
def errorResponse(errorMsg, status_code=400):
"""
统一的错误响应处理
:param errorMsg: 错误消息
:param status_code: HTTP状态码
:return: 错误响应
"""
safe_message = sanitize_error_message(errorMsg)
if 'application/json' in request.headers.get('Accept', ''):
return jsonify({
'success': False,
'error': safe_message
}), status_code
return render_template(
'error.html',
errorMsg=safe_message,
status_code=status_code
), status_code
+263 -146
View File
@@ -1,4 +1,4 @@
from flask import Flask, session, render_template, redirect, Blueprint, request, jsonify from flask import Flask, session, render_template, redirect, Blueprint, request, jsonify, abort
from utils.mynlp import SnowNLP from utils.mynlp import SnowNLP
from utils.getHomePageData import * from utils.getHomePageData import *
from utils.getHotWordPageData import * from utils.getHotWordPageData import *
@@ -16,12 +16,60 @@ from sqlalchemy import create_engine
import asyncio import asyncio
import torch import torch
from BCAT_front.predict import model_manager from BCAT_front.predict import model_manager
from functools import wraps
import bleach
import re
from datetime import datetime, timedelta
pb = Blueprint('page', pb = Blueprint('page',
__name__, __name__,
url_prefix='/page', url_prefix='/page',
template_folder='templates') template_folder='templates')
def sanitize_input(text):
"""清理用户输入,防止XSS攻击"""
if text is None:
return None
return bleach.clean(str(text), strip=True)
def validate_csrf_token():
"""验证CSRF令牌"""
token = request.form.get('csrf_token')
stored_token = session.get('csrf_token')
if not token or not stored_token or token != stored_token:
return False
return True
def login_required(f):
@wraps(f)
def decorated_function(*args, **kwargs):
if 'username' not in session:
return redirect('/user/login')
return f(*args, **kwargs)
return decorated_function
def api_login_required(f):
@wraps(f)
def decorated_function(*args, **kwargs):
if 'username' not in session:
return jsonify({'error': 'Unauthorized'}), 401
return f(*args, **kwargs)
return decorated_function
def rate_limit(f):
@wraps(f)
def decorated_function(*args, **kwargs):
key = f"rate_limit:{request.remote_addr}:{f.__name__}"
current = int(redis_client.get(key) or 0)
if current >= 100: # 每分钟100次请求限制
return jsonify({'error': 'Too many requests'}), 429
pipe = redis_client.pipeline()
pipe.incr(key)
pipe.expire(key, 60) # 60秒后重置
pipe.execute()
return f(*args, **kwargs)
return decorated_function
# 设置设备 # 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -37,14 +85,22 @@ except Exception as e:
logging.error(f"模型加载失败: {e}") logging.error(f"模型加载失败: {e}")
# 数据库配置 # 数据库配置
DATABASE_URL = "sqlite:///ai_analysis.db" DATABASE_URL = os.getenv('DATABASE_URL', "sqlite:///ai_analysis.db")
engine = create_engine(DATABASE_URL) engine = create_engine(DATABASE_URL)
AIAnalysis.metadata.create_all(engine) AIAnalysis.metadata.create_all(engine)
def predict_sentiment(text): def predict_sentiment(text):
"""使用改进版模型预测单个文本的情感""" """使用改进版模型预测单个文本的情感"""
try: try:
predictions, probabilities = model_manager.predict_batch([text]) if not text or len(text.strip()) == 0:
return None, None
# 清理输入
cleaned_text = sanitize_input(text)
if not cleaned_text:
return None, None
predictions, probabilities = model_manager.predict_batch([cleaned_text])
if predictions is not None and len(predictions) > 0: if predictions is not None and len(predictions) > 0:
return predictions[0], probabilities[0][predictions[0]] return predictions[0], probabilities[0][predictions[0]]
return None, None return None, None
@@ -53,55 +109,70 @@ def predict_sentiment(text):
return None, None return None, None
@pb.route('/home') @pb.route('/home')
@login_required
def home(): def home():
username = session.get('username') try:
articleLenMax, likeCountMaxAuthorName, cityMax = getHomeTagsData() username = session.get('username')
commentsLikeCountTopFore = getHomeCommentsLikeCountTopFore() articleLenMax, likeCountMaxAuthorName, cityMax = getHomeTagsData()
X, Y = getHomeArticleCreatedAtChart() commentsLikeCountTopFore = getHomeCommentsLikeCountTopFore()
typeChart = getHomeTypeChart() X, Y = getHomeArticleCreatedAtChart()
createAtChart = getHomeCommentCreatedChart() typeChart = getHomeTypeChart()
# getUserNameWordCloud() createAtChart = getHomeCommentCreatedChart()
return render_template('index.html',
username=username, return render_template('index.html',
articleLenMax=articleLenMax, username=username,
likeCountMaxAuthorName=likeCountMaxAuthorName, articleLenMax=articleLenMax,
cityMax=cityMax, likeCountMaxAuthorName=likeCountMaxAuthorName,
commentsLikeCountTopFore=commentsLikeCountTopFore, cityMax=cityMax,
xData=X, commentsLikeCountTopFore=commentsLikeCountTopFore,
yData=Y, xData=X,
typeChart=typeChart, yData=Y,
createAtChart=createAtChart) typeChart=typeChart,
createAtChart=createAtChart)
except Exception as e:
logging.error(f"加载首页时发生错误: {e}")
return render_template('error.html', error_message="加载首页失败")
@pb.route('/hotWord') @pb.route('/hotWord')
@login_required
def hotWord(): def hotWord():
username = session.get('username') try:
hotWordList = getAllHotWords() username = session.get('username')
print(hotWordList) hotWordList = getAllHotWords()
defaultHotWord = hotWordList[0][0] if not hotWordList:
if request.args.get('hotWord'): return render_template('error.html', error_message="无法获取热词列表")
defaultHotWord = request.args.get('hotWord')
hotWordLen = getHotWordLen(defaultHotWord) defaultHotWord = sanitize_input(request.args.get('hotWord', hotWordList[0][0]))
X, Y = getHotWordPageCreatedAtCharData(defaultHotWord)
sentences = '' # 验证热词是否在列表中
value = SnowNLP(defaultHotWord).sentiments if not any(defaultHotWord in word for word in hotWordList):
if value == 0.5: return abort(400, "无效的热词")
sentences = '中性'
elif value > 0.5: hotWordLen = getHotWordLen(defaultHotWord)
sentences = '正面' X, Y = getHotWordPageCreatedAtCharData(defaultHotWord)
elif value < 0.5:
sentences = '负面' value = SnowNLP(defaultHotWord).sentiments
comments = getCommentFilterData(defaultHotWord) if value == 0.5:
return render_template('hotWord.html', sentences = '中性'
username=username, elif value > 0.5:
hotWordList=hotWordList, sentences = '正面'
defaultHotWord=defaultHotWord, elif value < 0.5:
hotWordLen=hotWordLen, sentences = '负面'
sentences=sentences,
xData=X, comments = getCommentFilterData(defaultHotWord)
yData=Y,
comments=comments) return render_template('hotWord.html',
username=username,
hotWordList=hotWordList,
defaultHotWord=defaultHotWord,
hotWordLen=hotWordLen,
sentences=sentences,
xData=X,
yData=Y,
comments=comments)
except Exception as e:
logging.error(f"加载热词页面时发生错误: {e}")
return render_template('error.html', error_message="加载热词页面失败")
@pb.route('/hotTopic') @pb.route('/hotTopic')
def hotTopic(): def hotTopic():
@@ -127,18 +198,21 @@ def hotTopic():
yData=Y, yData=Y,
comments=comments) comments=comments)
@pb.route('/tableData') @pb.route('/tableData')
@login_required
def tableData(): def tableData():
username = session.get('username') try:
defaultFlag = False username = session.get('username')
if request.args.get('flag'): defaultFlag = True defaultFlag = bool(request.args.get('flag', False))
tableData = getTableDataList(defaultFlag) tableData = getTableDataList(defaultFlag)
return render_template('tableData.html',
username=username, return render_template('tableData.html',
tableData=tableData, username=username,
defaultFlag=defaultFlag) tableData=tableData,
defaultFlag=defaultFlag)
except Exception as e:
logging.error(f"加载表格数据时发生错误: {e}")
return render_template('error.html', error_message="加载表格数据失败")
@pb.route('/articleChar') @pb.route('/articleChar')
def articleChar(): def articleChar():
@@ -160,63 +234,89 @@ def articleChar():
x2Data=x2Data, x2Data=x2Data,
y2Data=y2Data) y2Data=y2Data)
@pb.route('/ipChar') @pb.route('/ipChar')
@login_required
def ipChar(): def ipChar():
username = session.get('username') try:
articleRegionData = getIPByArticleRegion() username = session.get('username')
commentRegionData = getIPByCommentsRegion() articleRegionData = getIPByArticleRegion()
return render_template('ipChar.html', commentRegionData = getIPByCommentsRegion()
username=username,
articleRegionData=articleRegionData, return render_template('ipChar.html',
commentRegionData=commentRegionData) username=username,
articleRegionData=articleRegionData,
commentRegionData=commentRegionData)
except Exception as e:
logging.error(f"加载IP统计时发生错误: {e}")
return render_template('error.html', error_message="加载IP统计失败")
@pb.route('/commentChar') @pb.route('/commentChar')
@login_required
def commentChar(): def commentChar():
username = session.get('username') try:
X, Y = getCommentDataOne() username = session.get('username')
genderPieData = getCommentDataTwo() X, Y = getCommentDataOne()
return render_template('commentChar.html', genderPieData = getCommentDataTwo()
username=username,
xData=X, return render_template('commentChar.html',
yData=Y, username=username,
genderPieData=genderPieData) xData=X,
yData=Y,
genderPieData=genderPieData)
except Exception as e:
logging.error(f"加载评论统计时发生错误: {e}")
return render_template('error.html', error_message="加载评论统计失败")
@pb.route('/yuqingChar') @pb.route('/yuqingChar')
@login_required
def yuqingChar(): def yuqingChar():
username = session.get('username') try:
# 获取模型选择参数 username = session.get('username')
model_type = request.args.get('model', 'pro') # 默认使用改进模型 model_type = sanitize_input(request.args.get('model', 'pro'))
X, Y, biedata = getYuQingCharDataOne() # 验证模型类型
biedata1, biedata2 = getYuQingCharDataTwo(model_type) if model_type not in ['pro', 'basic']:
x1Data, y1Data = getYuQingCharDataThree() return abort(400, "无效的模型类型")
return render_template('yuqingChar.html',
username=username, X, Y, biedata = getYuQingCharDataOne()
xData=X, biedata1, biedata2 = getYuQingCharDataTwo(model_type)
yData=Y, x1Data, y1Data = getYuQingCharDataThree()
biedata=biedata,
biedata1=biedata1, return render_template('yuqingChar.html',
biedata2=biedata2, username=username,
x1Data=x1Data, xData=X,
y1Data=y1Data, yData=Y,
model_type=model_type) biedata=biedata,
biedata1=biedata1,
biedata2=biedata2,
x1Data=x1Data,
y1Data=y1Data,
model_type=model_type)
except Exception as e:
logging.error(f"加载舆情统计时发生错误: {e}")
return render_template('error.html', error_message="加载舆情统计失败")
@pb.route('/yuqingpredict') @pb.route('/yuqingpredict')
@login_required
def yuqingpredict(): def yuqingpredict():
try: try:
username = session.get('username') username = session.get('username')
TopicList = getAllTopicData() TopicList = getAllTopicData()
defaultTopic = TopicList[0][0] if not TopicList:
if request.args.get('Topic'): return render_template('error.html', error_message="无法获取话题列表")
defaultTopic = request.args.get('Topic')
defaultTopic = sanitize_input(request.args.get('Topic', TopicList[0][0]))
# 验证话题是否在列表中
if not any(defaultTopic in topic for topic in TopicList):
return abort(400, "无效的话题")
TopicLen = getTopicLen(defaultTopic) TopicLen = getTopicLen(defaultTopic)
X, Y = getTopicCreatedAtandpredictData(defaultTopic) X, Y = getTopicCreatedAtandpredictData(defaultTopic)
# 获取模型选择参数 model_type = sanitize_input(request.args.get('model', 'pro'))
model_type = request.args.get('model', 'pro') # 默认使用改进模型 if model_type not in ['pro', 'basic']:
return abort(400, "无效的模型类型")
# 尝试从缓存获取预测结果 # 尝试从缓存获取预测结果
cache_key = f"{defaultTopic}_{model_type}" cache_key = f"{defaultTopic}_{model_type}"
@@ -226,7 +326,6 @@ def yuqingpredict():
sentences = cached_result sentences = cached_result
else: else:
if model_type == 'basic': if model_type == 'basic':
# 使用基础模型(SnowNLP
value = SnowNLP(defaultTopic).sentiments value = SnowNLP(defaultTopic).sentiments
if value == 0.5: if value == 0.5:
sentences = '中性' sentences = '中性'
@@ -235,7 +334,6 @@ def yuqingpredict():
elif value < 0.5: elif value < 0.5:
sentences = '负面' sentences = '负面'
else: else:
# 使用改进模型
predicted_label, confidence = predict_sentiment(defaultTopic) predicted_label, confidence = predict_sentiment(defaultTopic)
if predicted_label is not None: if predicted_label is not None:
sentences = '良好' if predicted_label == 0 else '不良' sentences = '良好' if predicted_label == 0 else '不良'
@@ -248,26 +346,30 @@ def yuqingpredict():
prediction_cache.set(cache_key, sentences) prediction_cache.set(cache_key, sentences)
comments = getCommentFilterDataTopic(defaultTopic) comments = getCommentFilterDataTopic(defaultTopic)
return render_template('yuqingpredict.html', return render_template('yuqingpredict.html',
username=username, username=username,
hotWordList=TopicList, TopicList=TopicList,
defaultHotWord=defaultTopic, defaultTopic=defaultTopic,
hotWordLen=TopicLen, TopicLen=TopicLen,
sentences=sentences, sentences=sentences,
xData=X, xData=X,
yData=Y, yData=Y,
comments=comments, comments=comments,
model_type=model_type) model_type=model_type)
except Exception as e: except Exception as e:
logging.error(f"舆情预测页面渲染失败: {e}") logging.error(f"加载舆情预测时发生错误: {e}")
return render_template('error.html', error_message="加载舆情预测页面失败,请稍后重试") return render_template('error.html', error_message="加载舆情预测失败")
@pb.route('/articleCloud') @pb.route('/articleCloud')
@login_required
def articleCloud(): def articleCloud():
username = session.get('username') try:
return render_template('articleContentCloud.html', username=username) username = session.get('username')
return render_template('articleContentCloud.html', username=username)
except Exception as e:
logging.error(f"加载文章云图时发生错误: {e}")
return render_template('error.html', error_message="加载文章云图失败")
@pb.route('/page/index') @pb.route('/page/index')
def index(): def index():
@@ -306,15 +408,28 @@ def articleChar(id):
return render_template('error.html', error_message="加载文章详情失败") return render_template('error.html', error_message="加载文章详情失败")
@pb.route('/api/analyze_messages', methods=['POST']) @pb.route('/api/analyze_messages', methods=['POST'])
@api_login_required
@rate_limit
async def analyze_messages(): async def analyze_messages():
try: try:
# 获取请求参数 if not validate_csrf_token():
return jsonify({'error': 'Invalid CSRF token'}), 403
data = request.get_json() data = request.get_json()
batch_size = data.get('batch_size', 50) if not data:
model_type = data.get('model_type', 'gpt-3.5-turbo') return jsonify({'error': 'No data provided'}), 400
analysis_depth = data.get('analysis_depth', 'standard')
batch_size = min(int(data.get('batch_size', 50)), 100) # 限制批量大小
model_type = sanitize_input(data.get('model_type', 'gpt-3.5-turbo'))
analysis_depth = sanitize_input(data.get('analysis_depth', 'standard'))
# 验证参数
if model_type not in ['gpt-3.5-turbo', 'gpt-4']:
return jsonify({'error': 'Invalid model type'}), 400
if analysis_depth not in ['basic', 'standard', 'deep']:
return jsonify({'error': 'Invalid analysis depth'}), 400
# 获取最近的消息
messages = getRecentMessages(batch_size) messages = getRecentMessages(batch_size)
if not messages: if not messages:
return jsonify({ return jsonify({
@@ -322,7 +437,6 @@ async def analyze_messages():
'error': '没有找到需要分析的消息' 'error': '没有找到需要分析的消息'
}), 404 }), 404
# 调用AI进行分析
analysis_results = await ai_analyzer.analyze_messages( analysis_results = await ai_analyzer.analyze_messages(
messages=messages, messages=messages,
batch_size=batch_size, batch_size=batch_size,
@@ -336,22 +450,27 @@ async def analyze_messages():
'error': '分析过程中出现错误' 'error': '分析过程中出现错误'
}), 500 }), 500
# 保存到数据库 try:
with Session(engine) as session: with Session(engine) as session:
for result in analysis_results: for result in analysis_results:
analysis = AIAnalysis( analysis = AIAnalysis(
message_id=result['message_id'], message_id=result['message_id'],
sentiment=result['sentiment'], sentiment=result['sentiment'],
sentiment_score=float(result['sentiment_score']), sentiment_score=float(result['sentiment_score']),
keywords=result['keywords'], keywords=result['keywords'],
key_points=result['key_points'], key_points=result['key_points'],
influence_analysis=result['influence_analysis'], influence_analysis=result['influence_analysis'],
risk_level=result['risk_level'] risk_level=result['risk_level']
) )
session.add(analysis) session.add(analysis)
session.commit() session.commit()
except Exception as e:
logging.error(f"保存分析结果时出错: {e}")
return jsonify({
'success': False,
'error': '保存分析结果失败'
}), 500
# 格式化结果用于显示
display_results = [ display_results = [
ai_analyzer.format_analysis_for_display(result) ai_analyzer.format_analysis_for_display(result)
for result in analysis_results for result in analysis_results
@@ -359,27 +478,25 @@ async def analyze_messages():
return jsonify({ return jsonify({
'success': True, 'success': True,
'data': display_results, 'data': display_results
'meta': {
'total_messages': len(messages),
'analyzed_messages': len(analysis_results),
'batch_size': batch_size,
'model_type': model_type,
'analysis_depth': analysis_depth
}
}) })
except Exception as e: except Exception as e:
logging.error(f"AI分析过程出错: {e}") logging.error(f"分析消息时发生错误: {e}")
return jsonify({ return jsonify({
'success': False, 'success': False,
'error': str(e) 'error': str(e)
}), 500 }), 500
@pb.route('/api/get_analysis/<int:message_id>') @pb.route('/api/get_analysis/<int:message_id>')
@api_login_required
@rate_limit
def get_message_analysis(message_id): def get_message_analysis(message_id):
"""获取特定消息的分析结果""" """获取特定消息的分析结果"""
try: try:
if not message_id or message_id < 1:
return jsonify({'error': 'Invalid message ID'}), 400
with Session(engine) as session: with Session(engine) as session:
analysis = session.query(AIAnalysis)\ analysis = session.query(AIAnalysis)\
.filter(AIAnalysis.message_id == message_id)\ .filter(AIAnalysis.message_id == message_id)\