Comprehensive security enhancement, fix race conditions and injection vulnerabilities.

This commit is contained in:
戒酒的李白
2025-03-08 00:17:42 +08:00
parent 5630b30002
commit f81a71e970
3 changed files with 451 additions and 344 deletions
+263 -146
View File
@@ -1,4 +1,4 @@
from flask import Flask, session, render_template, redirect, Blueprint, request, jsonify
from flask import Flask, session, render_template, redirect, Blueprint, request, jsonify, abort
from utils.mynlp import SnowNLP
from utils.getHomePageData import *
from utils.getHotWordPageData import *
@@ -16,12 +16,60 @@ from sqlalchemy import create_engine
import asyncio
import torch
from BCAT_front.predict import model_manager
from functools import wraps
import bleach
import re
from datetime import datetime, timedelta
pb = Blueprint('page',
__name__,
url_prefix='/page',
template_folder='templates')
def sanitize_input(text):
"""清理用户输入,防止XSS攻击"""
if text is None:
return None
return bleach.clean(str(text), strip=True)
def validate_csrf_token():
"""验证CSRF令牌"""
token = request.form.get('csrf_token')
stored_token = session.get('csrf_token')
if not token or not stored_token or token != stored_token:
return False
return True
def login_required(f):
@wraps(f)
def decorated_function(*args, **kwargs):
if 'username' not in session:
return redirect('/user/login')
return f(*args, **kwargs)
return decorated_function
def api_login_required(f):
@wraps(f)
def decorated_function(*args, **kwargs):
if 'username' not in session:
return jsonify({'error': 'Unauthorized'}), 401
return f(*args, **kwargs)
return decorated_function
def rate_limit(f):
@wraps(f)
def decorated_function(*args, **kwargs):
key = f"rate_limit:{request.remote_addr}:{f.__name__}"
current = int(redis_client.get(key) or 0)
if current >= 100: # 每分钟100次请求限制
return jsonify({'error': 'Too many requests'}), 429
pipe = redis_client.pipeline()
pipe.incr(key)
pipe.expire(key, 60) # 60秒后重置
pipe.execute()
return f(*args, **kwargs)
return decorated_function
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -37,14 +85,22 @@ except Exception as e:
logging.error(f"模型加载失败: {e}")
# 数据库配置
DATABASE_URL = "sqlite:///ai_analysis.db"
DATABASE_URL = os.getenv('DATABASE_URL', "sqlite:///ai_analysis.db")
engine = create_engine(DATABASE_URL)
AIAnalysis.metadata.create_all(engine)
def predict_sentiment(text):
"""使用改进版模型预测单个文本的情感"""
try:
predictions, probabilities = model_manager.predict_batch([text])
if not text or len(text.strip()) == 0:
return None, None
# 清理输入
cleaned_text = sanitize_input(text)
if not cleaned_text:
return None, None
predictions, probabilities = model_manager.predict_batch([cleaned_text])
if predictions is not None and len(predictions) > 0:
return predictions[0], probabilities[0][predictions[0]]
return None, None
@@ -53,55 +109,70 @@ def predict_sentiment(text):
return None, None
@pb.route('/home')
@login_required
def home():
username = session.get('username')
articleLenMax, likeCountMaxAuthorName, cityMax = getHomeTagsData()
commentsLikeCountTopFore = getHomeCommentsLikeCountTopFore()
X, Y = getHomeArticleCreatedAtChart()
typeChart = getHomeTypeChart()
createAtChart = getHomeCommentCreatedChart()
# getUserNameWordCloud()
return render_template('index.html',
username=username,
articleLenMax=articleLenMax,
likeCountMaxAuthorName=likeCountMaxAuthorName,
cityMax=cityMax,
commentsLikeCountTopFore=commentsLikeCountTopFore,
xData=X,
yData=Y,
typeChart=typeChart,
createAtChart=createAtChart)
try:
username = session.get('username')
articleLenMax, likeCountMaxAuthorName, cityMax = getHomeTagsData()
commentsLikeCountTopFore = getHomeCommentsLikeCountTopFore()
X, Y = getHomeArticleCreatedAtChart()
typeChart = getHomeTypeChart()
createAtChart = getHomeCommentCreatedChart()
return render_template('index.html',
username=username,
articleLenMax=articleLenMax,
likeCountMaxAuthorName=likeCountMaxAuthorName,
cityMax=cityMax,
commentsLikeCountTopFore=commentsLikeCountTopFore,
xData=X,
yData=Y,
typeChart=typeChart,
createAtChart=createAtChart)
except Exception as e:
logging.error(f"加载首页时发生错误: {e}")
return render_template('error.html', error_message="加载首页失败")
@pb.route('/hotWord')
@login_required
def hotWord():
username = session.get('username')
hotWordList = getAllHotWords()
print(hotWordList)
defaultHotWord = hotWordList[0][0]
if request.args.get('hotWord'):
defaultHotWord = request.args.get('hotWord')
hotWordLen = getHotWordLen(defaultHotWord)
X, Y = getHotWordPageCreatedAtCharData(defaultHotWord)
sentences = ''
value = SnowNLP(defaultHotWord).sentiments
if value == 0.5:
sentences = '中性'
elif value > 0.5:
sentences = '正面'
elif value < 0.5:
sentences = '负面'
comments = getCommentFilterData(defaultHotWord)
return render_template('hotWord.html',
username=username,
hotWordList=hotWordList,
defaultHotWord=defaultHotWord,
hotWordLen=hotWordLen,
sentences=sentences,
xData=X,
yData=Y,
comments=comments)
try:
username = session.get('username')
hotWordList = getAllHotWords()
if not hotWordList:
return render_template('error.html', error_message="无法获取热词列表")
defaultHotWord = sanitize_input(request.args.get('hotWord', hotWordList[0][0]))
# 验证热词是否在列表中
if not any(defaultHotWord in word for word in hotWordList):
return abort(400, "无效的热词")
hotWordLen = getHotWordLen(defaultHotWord)
X, Y = getHotWordPageCreatedAtCharData(defaultHotWord)
value = SnowNLP(defaultHotWord).sentiments
if value == 0.5:
sentences = '中性'
elif value > 0.5:
sentences = '正面'
elif value < 0.5:
sentences = '负面'
comments = getCommentFilterData(defaultHotWord)
return render_template('hotWord.html',
username=username,
hotWordList=hotWordList,
defaultHotWord=defaultHotWord,
hotWordLen=hotWordLen,
sentences=sentences,
xData=X,
yData=Y,
comments=comments)
except Exception as e:
logging.error(f"加载热词页面时发生错误: {e}")
return render_template('error.html', error_message="加载热词页面失败")
@pb.route('/hotTopic')
def hotTopic():
@@ -127,18 +198,21 @@ def hotTopic():
yData=Y,
comments=comments)
@pb.route('/tableData')
@login_required
def tableData():
username = session.get('username')
defaultFlag = False
if request.args.get('flag'): defaultFlag = True
tableData = getTableDataList(defaultFlag)
return render_template('tableData.html',
username=username,
tableData=tableData,
defaultFlag=defaultFlag)
try:
username = session.get('username')
defaultFlag = bool(request.args.get('flag', False))
tableData = getTableDataList(defaultFlag)
return render_template('tableData.html',
username=username,
tableData=tableData,
defaultFlag=defaultFlag)
except Exception as e:
logging.error(f"加载表格数据时发生错误: {e}")
return render_template('error.html', error_message="加载表格数据失败")
@pb.route('/articleChar')
def articleChar():
@@ -160,63 +234,89 @@ def articleChar():
x2Data=x2Data,
y2Data=y2Data)
@pb.route('/ipChar')
@login_required
def ipChar():
username = session.get('username')
articleRegionData = getIPByArticleRegion()
commentRegionData = getIPByCommentsRegion()
return render_template('ipChar.html',
username=username,
articleRegionData=articleRegionData,
commentRegionData=commentRegionData)
try:
username = session.get('username')
articleRegionData = getIPByArticleRegion()
commentRegionData = getIPByCommentsRegion()
return render_template('ipChar.html',
username=username,
articleRegionData=articleRegionData,
commentRegionData=commentRegionData)
except Exception as e:
logging.error(f"加载IP统计时发生错误: {e}")
return render_template('error.html', error_message="加载IP统计失败")
@pb.route('/commentChar')
@login_required
def commentChar():
username = session.get('username')
X, Y = getCommentDataOne()
genderPieData = getCommentDataTwo()
return render_template('commentChar.html',
username=username,
xData=X,
yData=Y,
genderPieData=genderPieData)
try:
username = session.get('username')
X, Y = getCommentDataOne()
genderPieData = getCommentDataTwo()
return render_template('commentChar.html',
username=username,
xData=X,
yData=Y,
genderPieData=genderPieData)
except Exception as e:
logging.error(f"加载评论统计时发生错误: {e}")
return render_template('error.html', error_message="加载评论统计失败")
@pb.route('/yuqingChar')
@login_required
def yuqingChar():
username = session.get('username')
# 获取模型选择参数
model_type = request.args.get('model', 'pro') # 默认使用改进模型
X, Y, biedata = getYuQingCharDataOne()
biedata1, biedata2 = getYuQingCharDataTwo(model_type)
x1Data, y1Data = getYuQingCharDataThree()
return render_template('yuqingChar.html',
username=username,
xData=X,
yData=Y,
biedata=biedata,
biedata1=biedata1,
biedata2=biedata2,
x1Data=x1Data,
y1Data=y1Data,
model_type=model_type)
try:
username = session.get('username')
model_type = sanitize_input(request.args.get('model', 'pro'))
# 验证模型类型
if model_type not in ['pro', 'basic']:
return abort(400, "无效的模型类型")
X, Y, biedata = getYuQingCharDataOne()
biedata1, biedata2 = getYuQingCharDataTwo(model_type)
x1Data, y1Data = getYuQingCharDataThree()
return render_template('yuqingChar.html',
username=username,
xData=X,
yData=Y,
biedata=biedata,
biedata1=biedata1,
biedata2=biedata2,
x1Data=x1Data,
y1Data=y1Data,
model_type=model_type)
except Exception as e:
logging.error(f"加载舆情统计时发生错误: {e}")
return render_template('error.html', error_message="加载舆情统计失败")
@pb.route('/yuqingpredict')
@login_required
def yuqingpredict():
try:
username = session.get('username')
TopicList = getAllTopicData()
defaultTopic = TopicList[0][0]
if request.args.get('Topic'):
defaultTopic = request.args.get('Topic')
if not TopicList:
return render_template('error.html', error_message="无法获取话题列表")
defaultTopic = sanitize_input(request.args.get('Topic', TopicList[0][0]))
# 验证话题是否在列表中
if not any(defaultTopic in topic for topic in TopicList):
return abort(400, "无效的话题")
TopicLen = getTopicLen(defaultTopic)
X, Y = getTopicCreatedAtandpredictData(defaultTopic)
# 获取模型选择参数
model_type = request.args.get('model', 'pro') # 默认使用改进模型
model_type = sanitize_input(request.args.get('model', 'pro'))
if model_type not in ['pro', 'basic']:
return abort(400, "无效的模型类型")
# 尝试从缓存获取预测结果
cache_key = f"{defaultTopic}_{model_type}"
@@ -226,7 +326,6 @@ def yuqingpredict():
sentences = cached_result
else:
if model_type == 'basic':
# 使用基础模型(SnowNLP
value = SnowNLP(defaultTopic).sentiments
if value == 0.5:
sentences = '中性'
@@ -235,7 +334,6 @@ def yuqingpredict():
elif value < 0.5:
sentences = '负面'
else:
# 使用改进模型
predicted_label, confidence = predict_sentiment(defaultTopic)
if predicted_label is not None:
sentences = '良好' if predicted_label == 0 else '不良'
@@ -248,26 +346,30 @@ def yuqingpredict():
prediction_cache.set(cache_key, sentences)
comments = getCommentFilterDataTopic(defaultTopic)
return render_template('yuqingpredict.html',
username=username,
hotWordList=TopicList,
defaultHotWord=defaultTopic,
hotWordLen=TopicLen,
sentences=sentences,
xData=X,
yData=Y,
comments=comments,
model_type=model_type)
username=username,
TopicList=TopicList,
defaultTopic=defaultTopic,
TopicLen=TopicLen,
sentences=sentences,
xData=X,
yData=Y,
comments=comments,
model_type=model_type)
except Exception as e:
logging.error(f"舆情预测页面渲染失败: {e}")
return render_template('error.html', error_message="加载舆情预测页面失败,请稍后重试")
logging.error(f"加载舆情预测时发生错误: {e}")
return render_template('error.html', error_message="加载舆情预测失败")
@pb.route('/articleCloud')
@login_required
def articleCloud():
username = session.get('username')
return render_template('articleContentCloud.html', username=username)
try:
username = session.get('username')
return render_template('articleContentCloud.html', username=username)
except Exception as e:
logging.error(f"加载文章云图时发生错误: {e}")
return render_template('error.html', error_message="加载文章云图失败")
@pb.route('/page/index')
def index():
@@ -306,15 +408,28 @@ def articleChar(id):
return render_template('error.html', error_message="加载文章详情失败")
@pb.route('/api/analyze_messages', methods=['POST'])
@api_login_required
@rate_limit
async def analyze_messages():
try:
# 获取请求参数
if not validate_csrf_token():
return jsonify({'error': 'Invalid CSRF token'}), 403
data = request.get_json()
batch_size = data.get('batch_size', 50)
model_type = data.get('model_type', 'gpt-3.5-turbo')
analysis_depth = data.get('analysis_depth', 'standard')
if not data:
return jsonify({'error': 'No data provided'}), 400
batch_size = min(int(data.get('batch_size', 50)), 100) # 限制批量大小
model_type = sanitize_input(data.get('model_type', 'gpt-3.5-turbo'))
analysis_depth = sanitize_input(data.get('analysis_depth', 'standard'))
# 验证参数
if model_type not in ['gpt-3.5-turbo', 'gpt-4']:
return jsonify({'error': 'Invalid model type'}), 400
if analysis_depth not in ['basic', 'standard', 'deep']:
return jsonify({'error': 'Invalid analysis depth'}), 400
# 获取最近的消息
messages = getRecentMessages(batch_size)
if not messages:
return jsonify({
@@ -322,7 +437,6 @@ async def analyze_messages():
'error': '没有找到需要分析的消息'
}), 404
# 调用AI进行分析
analysis_results = await ai_analyzer.analyze_messages(
messages=messages,
batch_size=batch_size,
@@ -336,22 +450,27 @@ async def analyze_messages():
'error': '分析过程中出现错误'
}), 500
# 保存到数据库
with Session(engine) as session:
for result in analysis_results:
analysis = AIAnalysis(
message_id=result['message_id'],
sentiment=result['sentiment'],
sentiment_score=float(result['sentiment_score']),
keywords=result['keywords'],
key_points=result['key_points'],
influence_analysis=result['influence_analysis'],
risk_level=result['risk_level']
)
session.add(analysis)
session.commit()
try:
with Session(engine) as session:
for result in analysis_results:
analysis = AIAnalysis(
message_id=result['message_id'],
sentiment=result['sentiment'],
sentiment_score=float(result['sentiment_score']),
keywords=result['keywords'],
key_points=result['key_points'],
influence_analysis=result['influence_analysis'],
risk_level=result['risk_level']
)
session.add(analysis)
session.commit()
except Exception as e:
logging.error(f"保存分析结果时出错: {e}")
return jsonify({
'success': False,
'error': '保存分析结果失败'
}), 500
# 格式化结果用于显示
display_results = [
ai_analyzer.format_analysis_for_display(result)
for result in analysis_results
@@ -359,27 +478,25 @@ async def analyze_messages():
return jsonify({
'success': True,
'data': display_results,
'meta': {
'total_messages': len(messages),
'analyzed_messages': len(analysis_results),
'batch_size': batch_size,
'model_type': model_type,
'analysis_depth': analysis_depth
}
'data': display_results
})
except Exception as e:
logging.error(f"AI分析过程出错: {e}")
logging.error(f"分析消息时发生错误: {e}")
return jsonify({
'success': False,
'error': str(e)
}), 500
@pb.route('/api/get_analysis/<int:message_id>')
@api_login_required
@rate_limit
def get_message_analysis(message_id):
"""获取特定消息的分析结果"""
try:
if not message_id or message_id < 1:
return jsonify({'error': 'Invalid message ID'}), 400
with Session(engine) as session:
analysis = session.query(AIAnalysis)\
.filter(AIAnalysis.message_id == message_id)\