Comprehensive security enhancement, fix race conditions and injection vulnerabilities.
This commit is contained in:
+263
-146
@@ -1,4 +1,4 @@
|
||||
from flask import Flask, session, render_template, redirect, Blueprint, request, jsonify
|
||||
from flask import Flask, session, render_template, redirect, Blueprint, request, jsonify, abort
|
||||
from utils.mynlp import SnowNLP
|
||||
from utils.getHomePageData import *
|
||||
from utils.getHotWordPageData import *
|
||||
@@ -16,12 +16,60 @@ from sqlalchemy import create_engine
|
||||
import asyncio
|
||||
import torch
|
||||
from BCAT_front.predict import model_manager
|
||||
from functools import wraps
|
||||
import bleach
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
pb = Blueprint('page',
|
||||
__name__,
|
||||
url_prefix='/page',
|
||||
template_folder='templates')
|
||||
|
||||
def sanitize_input(text):
|
||||
"""清理用户输入,防止XSS攻击"""
|
||||
if text is None:
|
||||
return None
|
||||
return bleach.clean(str(text), strip=True)
|
||||
|
||||
def validate_csrf_token():
|
||||
"""验证CSRF令牌"""
|
||||
token = request.form.get('csrf_token')
|
||||
stored_token = session.get('csrf_token')
|
||||
if not token or not stored_token or token != stored_token:
|
||||
return False
|
||||
return True
|
||||
|
||||
def login_required(f):
|
||||
@wraps(f)
|
||||
def decorated_function(*args, **kwargs):
|
||||
if 'username' not in session:
|
||||
return redirect('/user/login')
|
||||
return f(*args, **kwargs)
|
||||
return decorated_function
|
||||
|
||||
def api_login_required(f):
|
||||
@wraps(f)
|
||||
def decorated_function(*args, **kwargs):
|
||||
if 'username' not in session:
|
||||
return jsonify({'error': 'Unauthorized'}), 401
|
||||
return f(*args, **kwargs)
|
||||
return decorated_function
|
||||
|
||||
def rate_limit(f):
|
||||
@wraps(f)
|
||||
def decorated_function(*args, **kwargs):
|
||||
key = f"rate_limit:{request.remote_addr}:{f.__name__}"
|
||||
current = int(redis_client.get(key) or 0)
|
||||
if current >= 100: # 每分钟100次请求限制
|
||||
return jsonify({'error': 'Too many requests'}), 429
|
||||
pipe = redis_client.pipeline()
|
||||
pipe.incr(key)
|
||||
pipe.expire(key, 60) # 60秒后重置
|
||||
pipe.execute()
|
||||
return f(*args, **kwargs)
|
||||
return decorated_function
|
||||
|
||||
# 设置设备
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
@@ -37,14 +85,22 @@ except Exception as e:
|
||||
logging.error(f"模型加载失败: {e}")
|
||||
|
||||
# 数据库配置
|
||||
DATABASE_URL = "sqlite:///ai_analysis.db"
|
||||
DATABASE_URL = os.getenv('DATABASE_URL', "sqlite:///ai_analysis.db")
|
||||
engine = create_engine(DATABASE_URL)
|
||||
AIAnalysis.metadata.create_all(engine)
|
||||
|
||||
def predict_sentiment(text):
|
||||
"""使用改进版模型预测单个文本的情感"""
|
||||
try:
|
||||
predictions, probabilities = model_manager.predict_batch([text])
|
||||
if not text or len(text.strip()) == 0:
|
||||
return None, None
|
||||
|
||||
# 清理输入
|
||||
cleaned_text = sanitize_input(text)
|
||||
if not cleaned_text:
|
||||
return None, None
|
||||
|
||||
predictions, probabilities = model_manager.predict_batch([cleaned_text])
|
||||
if predictions is not None and len(predictions) > 0:
|
||||
return predictions[0], probabilities[0][predictions[0]]
|
||||
return None, None
|
||||
@@ -53,55 +109,70 @@ def predict_sentiment(text):
|
||||
return None, None
|
||||
|
||||
@pb.route('/home')
|
||||
@login_required
|
||||
def home():
|
||||
username = session.get('username')
|
||||
articleLenMax, likeCountMaxAuthorName, cityMax = getHomeTagsData()
|
||||
commentsLikeCountTopFore = getHomeCommentsLikeCountTopFore()
|
||||
X, Y = getHomeArticleCreatedAtChart()
|
||||
typeChart = getHomeTypeChart()
|
||||
createAtChart = getHomeCommentCreatedChart()
|
||||
# getUserNameWordCloud()
|
||||
return render_template('index.html',
|
||||
username=username,
|
||||
articleLenMax=articleLenMax,
|
||||
likeCountMaxAuthorName=likeCountMaxAuthorName,
|
||||
cityMax=cityMax,
|
||||
commentsLikeCountTopFore=commentsLikeCountTopFore,
|
||||
xData=X,
|
||||
yData=Y,
|
||||
typeChart=typeChart,
|
||||
createAtChart=createAtChart)
|
||||
|
||||
try:
|
||||
username = session.get('username')
|
||||
articleLenMax, likeCountMaxAuthorName, cityMax = getHomeTagsData()
|
||||
commentsLikeCountTopFore = getHomeCommentsLikeCountTopFore()
|
||||
X, Y = getHomeArticleCreatedAtChart()
|
||||
typeChart = getHomeTypeChart()
|
||||
createAtChart = getHomeCommentCreatedChart()
|
||||
|
||||
return render_template('index.html',
|
||||
username=username,
|
||||
articleLenMax=articleLenMax,
|
||||
likeCountMaxAuthorName=likeCountMaxAuthorName,
|
||||
cityMax=cityMax,
|
||||
commentsLikeCountTopFore=commentsLikeCountTopFore,
|
||||
xData=X,
|
||||
yData=Y,
|
||||
typeChart=typeChart,
|
||||
createAtChart=createAtChart)
|
||||
except Exception as e:
|
||||
logging.error(f"加载首页时发生错误: {e}")
|
||||
return render_template('error.html', error_message="加载首页失败")
|
||||
|
||||
@pb.route('/hotWord')
|
||||
@login_required
|
||||
def hotWord():
|
||||
username = session.get('username')
|
||||
hotWordList = getAllHotWords()
|
||||
print(hotWordList)
|
||||
defaultHotWord = hotWordList[0][0]
|
||||
if request.args.get('hotWord'):
|
||||
defaultHotWord = request.args.get('hotWord')
|
||||
hotWordLen = getHotWordLen(defaultHotWord)
|
||||
X, Y = getHotWordPageCreatedAtCharData(defaultHotWord)
|
||||
sentences = ''
|
||||
value = SnowNLP(defaultHotWord).sentiments
|
||||
if value == 0.5:
|
||||
sentences = '中性'
|
||||
elif value > 0.5:
|
||||
sentences = '正面'
|
||||
elif value < 0.5:
|
||||
sentences = '负面'
|
||||
comments = getCommentFilterData(defaultHotWord)
|
||||
return render_template('hotWord.html',
|
||||
username=username,
|
||||
hotWordList=hotWordList,
|
||||
defaultHotWord=defaultHotWord,
|
||||
hotWordLen=hotWordLen,
|
||||
sentences=sentences,
|
||||
xData=X,
|
||||
yData=Y,
|
||||
comments=comments)
|
||||
|
||||
try:
|
||||
username = session.get('username')
|
||||
hotWordList = getAllHotWords()
|
||||
if not hotWordList:
|
||||
return render_template('error.html', error_message="无法获取热词列表")
|
||||
|
||||
defaultHotWord = sanitize_input(request.args.get('hotWord', hotWordList[0][0]))
|
||||
|
||||
# 验证热词是否在列表中
|
||||
if not any(defaultHotWord in word for word in hotWordList):
|
||||
return abort(400, "无效的热词")
|
||||
|
||||
hotWordLen = getHotWordLen(defaultHotWord)
|
||||
X, Y = getHotWordPageCreatedAtCharData(defaultHotWord)
|
||||
|
||||
value = SnowNLP(defaultHotWord).sentiments
|
||||
if value == 0.5:
|
||||
sentences = '中性'
|
||||
elif value > 0.5:
|
||||
sentences = '正面'
|
||||
elif value < 0.5:
|
||||
sentences = '负面'
|
||||
|
||||
comments = getCommentFilterData(defaultHotWord)
|
||||
|
||||
return render_template('hotWord.html',
|
||||
username=username,
|
||||
hotWordList=hotWordList,
|
||||
defaultHotWord=defaultHotWord,
|
||||
hotWordLen=hotWordLen,
|
||||
sentences=sentences,
|
||||
xData=X,
|
||||
yData=Y,
|
||||
comments=comments)
|
||||
except Exception as e:
|
||||
logging.error(f"加载热词页面时发生错误: {e}")
|
||||
return render_template('error.html', error_message="加载热词页面失败")
|
||||
|
||||
@pb.route('/hotTopic')
|
||||
def hotTopic():
|
||||
@@ -127,18 +198,21 @@ def hotTopic():
|
||||
yData=Y,
|
||||
comments=comments)
|
||||
|
||||
|
||||
@pb.route('/tableData')
|
||||
@login_required
|
||||
def tableData():
|
||||
username = session.get('username')
|
||||
defaultFlag = False
|
||||
if request.args.get('flag'): defaultFlag = True
|
||||
tableData = getTableDataList(defaultFlag)
|
||||
return render_template('tableData.html',
|
||||
username=username,
|
||||
tableData=tableData,
|
||||
defaultFlag=defaultFlag)
|
||||
|
||||
try:
|
||||
username = session.get('username')
|
||||
defaultFlag = bool(request.args.get('flag', False))
|
||||
tableData = getTableDataList(defaultFlag)
|
||||
|
||||
return render_template('tableData.html',
|
||||
username=username,
|
||||
tableData=tableData,
|
||||
defaultFlag=defaultFlag)
|
||||
except Exception as e:
|
||||
logging.error(f"加载表格数据时发生错误: {e}")
|
||||
return render_template('error.html', error_message="加载表格数据失败")
|
||||
|
||||
@pb.route('/articleChar')
|
||||
def articleChar():
|
||||
@@ -160,63 +234,89 @@ def articleChar():
|
||||
x2Data=x2Data,
|
||||
y2Data=y2Data)
|
||||
|
||||
|
||||
@pb.route('/ipChar')
|
||||
@login_required
|
||||
def ipChar():
|
||||
username = session.get('username')
|
||||
articleRegionData = getIPByArticleRegion()
|
||||
commentRegionData = getIPByCommentsRegion()
|
||||
return render_template('ipChar.html',
|
||||
username=username,
|
||||
articleRegionData=articleRegionData,
|
||||
commentRegionData=commentRegionData)
|
||||
|
||||
try:
|
||||
username = session.get('username')
|
||||
articleRegionData = getIPByArticleRegion()
|
||||
commentRegionData = getIPByCommentsRegion()
|
||||
|
||||
return render_template('ipChar.html',
|
||||
username=username,
|
||||
articleRegionData=articleRegionData,
|
||||
commentRegionData=commentRegionData)
|
||||
except Exception as e:
|
||||
logging.error(f"加载IP统计时发生错误: {e}")
|
||||
return render_template('error.html', error_message="加载IP统计失败")
|
||||
|
||||
@pb.route('/commentChar')
|
||||
@login_required
|
||||
def commentChar():
|
||||
username = session.get('username')
|
||||
X, Y = getCommentDataOne()
|
||||
genderPieData = getCommentDataTwo()
|
||||
return render_template('commentChar.html',
|
||||
username=username,
|
||||
xData=X,
|
||||
yData=Y,
|
||||
genderPieData=genderPieData)
|
||||
|
||||
try:
|
||||
username = session.get('username')
|
||||
X, Y = getCommentDataOne()
|
||||
genderPieData = getCommentDataTwo()
|
||||
|
||||
return render_template('commentChar.html',
|
||||
username=username,
|
||||
xData=X,
|
||||
yData=Y,
|
||||
genderPieData=genderPieData)
|
||||
except Exception as e:
|
||||
logging.error(f"加载评论统计时发生错误: {e}")
|
||||
return render_template('error.html', error_message="加载评论统计失败")
|
||||
|
||||
@pb.route('/yuqingChar')
|
||||
@login_required
|
||||
def yuqingChar():
|
||||
username = session.get('username')
|
||||
# 获取模型选择参数
|
||||
model_type = request.args.get('model', 'pro') # 默认使用改进模型
|
||||
|
||||
X, Y, biedata = getYuQingCharDataOne()
|
||||
biedata1, biedata2 = getYuQingCharDataTwo(model_type)
|
||||
x1Data, y1Data = getYuQingCharDataThree()
|
||||
return render_template('yuqingChar.html',
|
||||
username=username,
|
||||
xData=X,
|
||||
yData=Y,
|
||||
biedata=biedata,
|
||||
biedata1=biedata1,
|
||||
biedata2=biedata2,
|
||||
x1Data=x1Data,
|
||||
y1Data=y1Data,
|
||||
model_type=model_type)
|
||||
try:
|
||||
username = session.get('username')
|
||||
model_type = sanitize_input(request.args.get('model', 'pro'))
|
||||
|
||||
# 验证模型类型
|
||||
if model_type not in ['pro', 'basic']:
|
||||
return abort(400, "无效的模型类型")
|
||||
|
||||
X, Y, biedata = getYuQingCharDataOne()
|
||||
biedata1, biedata2 = getYuQingCharDataTwo(model_type)
|
||||
x1Data, y1Data = getYuQingCharDataThree()
|
||||
|
||||
return render_template('yuqingChar.html',
|
||||
username=username,
|
||||
xData=X,
|
||||
yData=Y,
|
||||
biedata=biedata,
|
||||
biedata1=biedata1,
|
||||
biedata2=biedata2,
|
||||
x1Data=x1Data,
|
||||
y1Data=y1Data,
|
||||
model_type=model_type)
|
||||
except Exception as e:
|
||||
logging.error(f"加载舆情统计时发生错误: {e}")
|
||||
return render_template('error.html', error_message="加载舆情统计失败")
|
||||
|
||||
@pb.route('/yuqingpredict')
|
||||
@login_required
|
||||
def yuqingpredict():
|
||||
try:
|
||||
username = session.get('username')
|
||||
TopicList = getAllTopicData()
|
||||
defaultTopic = TopicList[0][0]
|
||||
if request.args.get('Topic'):
|
||||
defaultTopic = request.args.get('Topic')
|
||||
if not TopicList:
|
||||
return render_template('error.html', error_message="无法获取话题列表")
|
||||
|
||||
defaultTopic = sanitize_input(request.args.get('Topic', TopicList[0][0]))
|
||||
|
||||
# 验证话题是否在列表中
|
||||
if not any(defaultTopic in topic for topic in TopicList):
|
||||
return abort(400, "无效的话题")
|
||||
|
||||
TopicLen = getTopicLen(defaultTopic)
|
||||
X, Y = getTopicCreatedAtandpredictData(defaultTopic)
|
||||
|
||||
# 获取模型选择参数
|
||||
model_type = request.args.get('model', 'pro') # 默认使用改进模型
|
||||
model_type = sanitize_input(request.args.get('model', 'pro'))
|
||||
if model_type not in ['pro', 'basic']:
|
||||
return abort(400, "无效的模型类型")
|
||||
|
||||
# 尝试从缓存获取预测结果
|
||||
cache_key = f"{defaultTopic}_{model_type}"
|
||||
@@ -226,7 +326,6 @@ def yuqingpredict():
|
||||
sentences = cached_result
|
||||
else:
|
||||
if model_type == 'basic':
|
||||
# 使用基础模型(SnowNLP)
|
||||
value = SnowNLP(defaultTopic).sentiments
|
||||
if value == 0.5:
|
||||
sentences = '中性'
|
||||
@@ -235,7 +334,6 @@ def yuqingpredict():
|
||||
elif value < 0.5:
|
||||
sentences = '负面'
|
||||
else:
|
||||
# 使用改进模型
|
||||
predicted_label, confidence = predict_sentiment(defaultTopic)
|
||||
if predicted_label is not None:
|
||||
sentences = '良好' if predicted_label == 0 else '不良'
|
||||
@@ -248,26 +346,30 @@ def yuqingpredict():
|
||||
prediction_cache.set(cache_key, sentences)
|
||||
|
||||
comments = getCommentFilterDataTopic(defaultTopic)
|
||||
|
||||
return render_template('yuqingpredict.html',
|
||||
username=username,
|
||||
hotWordList=TopicList,
|
||||
defaultHotWord=defaultTopic,
|
||||
hotWordLen=TopicLen,
|
||||
sentences=sentences,
|
||||
xData=X,
|
||||
yData=Y,
|
||||
comments=comments,
|
||||
model_type=model_type)
|
||||
username=username,
|
||||
TopicList=TopicList,
|
||||
defaultTopic=defaultTopic,
|
||||
TopicLen=TopicLen,
|
||||
sentences=sentences,
|
||||
xData=X,
|
||||
yData=Y,
|
||||
comments=comments,
|
||||
model_type=model_type)
|
||||
except Exception as e:
|
||||
logging.error(f"舆情预测页面渲染失败: {e}")
|
||||
return render_template('error.html', error_message="加载舆情预测页面失败,请稍后重试")
|
||||
|
||||
logging.error(f"加载舆情预测时发生错误: {e}")
|
||||
return render_template('error.html', error_message="加载舆情预测失败")
|
||||
|
||||
@pb.route('/articleCloud')
|
||||
@login_required
|
||||
def articleCloud():
|
||||
username = session.get('username')
|
||||
return render_template('articleContentCloud.html', username=username)
|
||||
|
||||
try:
|
||||
username = session.get('username')
|
||||
return render_template('articleContentCloud.html', username=username)
|
||||
except Exception as e:
|
||||
logging.error(f"加载文章云图时发生错误: {e}")
|
||||
return render_template('error.html', error_message="加载文章云图失败")
|
||||
|
||||
@pb.route('/page/index')
|
||||
def index():
|
||||
@@ -306,15 +408,28 @@ def articleChar(id):
|
||||
return render_template('error.html', error_message="加载文章详情失败")
|
||||
|
||||
@pb.route('/api/analyze_messages', methods=['POST'])
|
||||
@api_login_required
|
||||
@rate_limit
|
||||
async def analyze_messages():
|
||||
try:
|
||||
# 获取请求参数
|
||||
if not validate_csrf_token():
|
||||
return jsonify({'error': 'Invalid CSRF token'}), 403
|
||||
|
||||
data = request.get_json()
|
||||
batch_size = data.get('batch_size', 50)
|
||||
model_type = data.get('model_type', 'gpt-3.5-turbo')
|
||||
analysis_depth = data.get('analysis_depth', 'standard')
|
||||
if not data:
|
||||
return jsonify({'error': 'No data provided'}), 400
|
||||
|
||||
batch_size = min(int(data.get('batch_size', 50)), 100) # 限制批量大小
|
||||
model_type = sanitize_input(data.get('model_type', 'gpt-3.5-turbo'))
|
||||
analysis_depth = sanitize_input(data.get('analysis_depth', 'standard'))
|
||||
|
||||
# 验证参数
|
||||
if model_type not in ['gpt-3.5-turbo', 'gpt-4']:
|
||||
return jsonify({'error': 'Invalid model type'}), 400
|
||||
|
||||
if analysis_depth not in ['basic', 'standard', 'deep']:
|
||||
return jsonify({'error': 'Invalid analysis depth'}), 400
|
||||
|
||||
# 获取最近的消息
|
||||
messages = getRecentMessages(batch_size)
|
||||
if not messages:
|
||||
return jsonify({
|
||||
@@ -322,7 +437,6 @@ async def analyze_messages():
|
||||
'error': '没有找到需要分析的消息'
|
||||
}), 404
|
||||
|
||||
# 调用AI进行分析
|
||||
analysis_results = await ai_analyzer.analyze_messages(
|
||||
messages=messages,
|
||||
batch_size=batch_size,
|
||||
@@ -336,22 +450,27 @@ async def analyze_messages():
|
||||
'error': '分析过程中出现错误'
|
||||
}), 500
|
||||
|
||||
# 保存到数据库
|
||||
with Session(engine) as session:
|
||||
for result in analysis_results:
|
||||
analysis = AIAnalysis(
|
||||
message_id=result['message_id'],
|
||||
sentiment=result['sentiment'],
|
||||
sentiment_score=float(result['sentiment_score']),
|
||||
keywords=result['keywords'],
|
||||
key_points=result['key_points'],
|
||||
influence_analysis=result['influence_analysis'],
|
||||
risk_level=result['risk_level']
|
||||
)
|
||||
session.add(analysis)
|
||||
session.commit()
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
for result in analysis_results:
|
||||
analysis = AIAnalysis(
|
||||
message_id=result['message_id'],
|
||||
sentiment=result['sentiment'],
|
||||
sentiment_score=float(result['sentiment_score']),
|
||||
keywords=result['keywords'],
|
||||
key_points=result['key_points'],
|
||||
influence_analysis=result['influence_analysis'],
|
||||
risk_level=result['risk_level']
|
||||
)
|
||||
session.add(analysis)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
logging.error(f"保存分析结果时出错: {e}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': '保存分析结果失败'
|
||||
}), 500
|
||||
|
||||
# 格式化结果用于显示
|
||||
display_results = [
|
||||
ai_analyzer.format_analysis_for_display(result)
|
||||
for result in analysis_results
|
||||
@@ -359,27 +478,25 @@ async def analyze_messages():
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'data': display_results,
|
||||
'meta': {
|
||||
'total_messages': len(messages),
|
||||
'analyzed_messages': len(analysis_results),
|
||||
'batch_size': batch_size,
|
||||
'model_type': model_type,
|
||||
'analysis_depth': analysis_depth
|
||||
}
|
||||
'data': display_results
|
||||
})
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"AI分析过程出错: {e}")
|
||||
logging.error(f"分析消息时发生错误: {e}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}), 500
|
||||
|
||||
@pb.route('/api/get_analysis/<int:message_id>')
|
||||
@api_login_required
|
||||
@rate_limit
|
||||
def get_message_analysis(message_id):
|
||||
"""获取特定消息的分析结果"""
|
||||
try:
|
||||
if not message_id or message_id < 1:
|
||||
return jsonify({'error': 'Invalid message ID'}), 400
|
||||
|
||||
with Session(engine) as session:
|
||||
analysis = session.query(AIAnalysis)\
|
||||
.filter(AIAnalysis.message_id == message_id)\
|
||||
|
||||
Reference in New Issue
Block a user