diff --git a/app.py b/app.py index 832d53d..7b18720 100644 --- a/app.py +++ b/app.py @@ -3,13 +3,20 @@ import re import getpass import pymysql import subprocess -from flask import Flask, session, request, redirect, render_template +from flask import Flask, session, request, redirect, render_template, jsonify from apscheduler.schedulers.background import BackgroundScheduler from pytz import utc from datetime import datetime, timedelta import time from utils.logger import app_logger as logging from utils.db_manager import DatabaseManager +import secrets +from dotenv import load_dotenv +from functools import wraps +import bleach + +# 加载环境变量 +load_dotenv() def get_db_connection_interactive(): """ @@ -18,17 +25,17 @@ def get_db_connection_interactive(): """ print("请依次输入数据库连接信息(直接按回车使用默认值):") - host = input(" 1. 主机 (默认: localhost): ") or "localhost" - port_str = input(" 2. 端口 (默认: 3306): ") or "3306" + host = input(" 1. 主机 (默认: localhost): ") or os.getenv('DB_HOST', 'localhost') + port_str = input(" 2. 端口 (默认: 3306): ") or os.getenv('DB_PORT', '3306') try: port = int(port_str) except ValueError: logging.warning("端口号无效,使用默认端口 3306。") port = 3306 - user = input(" 3. 用户名 (默认: root): ") or "root" - password = getpass.getpass(" 4. 密码 (默认: 12345678): ") or "12345678" - db_name = input(" 5. 数据库名 (默认: Weibo_PublicOpinion_AnalysisSystem): ") or "Weibo_PublicOpinion_AnalysisSystem" + user = input(" 3. 用户名 (默认: root): ") or os.getenv('DB_USER', 'root') + password = getpass.getpass(" 4. 密码: ") or os.getenv('DB_PASSWORD', '') + db_name = input(" 5. 数据库名 (默认: Weibo_PublicOpinion_AnalysisSystem): ") or os.getenv('DB_NAME', 'Weibo_PublicOpinion_AnalysisSystem') logging.info(f"尝试连接到数据库: {user}@{host}:{port}/{db_name}") @@ -40,237 +47,183 @@ def get_db_connection_interactive(): password=password, database=db_name, charset='utf8mb4', - cursorclass=pymysql.cursors.DictCursor # 返回字典格式 + cursorclass=pymysql.cursors.DictCursor, + ssl={'ssl': {'ca': os.getenv('DB_SSL_CA')}} if os.getenv('DB_SSL_CA') else None ) logging.info("数据库连接成功。") return connection except pymysql.MySQLError as e: logging.error(f"数据库连接失败: {e}") - exit(1) + raise -def initialize_database(connection, sql_file_path): - """ - 执行 SQL 文件中的语句以初始化数据库。 - - :param connection: 已建立的数据库连接 - :param sql_file_path: SQL 文件的路径 - """ - try: - with open(sql_file_path, 'r', encoding='utf8') as file: - sql_commands = file.read() - - with connection.cursor() as cursor: - for statement in sql_commands.split(';'): - statement = statement.strip() - if statement: - cursor.execute(statement) - connection.commit() - logging.info("数据库初始化成功。") - except FileNotFoundError: - logging.error(f"SQL 文件未找到: {sql_file_path}") - exit(1) - except pymysql.MySQLError as e: - logging.error(f"执行 SQL 时出错: {e}") - connection.rollback() - exit(1) - except Exception as e: - logging.error(f"初始化数据库时出错: {e}") - connection.rollback() - exit(1) +def sanitize_input(text): + """清理用户输入,防止XSS攻击""" + if text is None: + return None + return bleach.clean(str(text), strip=True) -def prompt_first_run(): - """ - 询问用户是否首次运行,需要初始化数据库。 - - :return: Boolean,True 表示需要初始化数据库 - """ - while True: - choice = input("是否首次运行该项目,需要初始化数据库?(Y/n): ").strip().lower() - if choice in ['y', 'yes', '']: - return True - elif choice in ['n', 'no']: - return False - else: - print("请输入 Y 或 N。") +def set_secure_headers(response): + """设置安全响应头""" + response.headers['X-Content-Type-Options'] = 'nosniff' + response.headers['X-Frame-Options'] = 'SAMEORIGIN' + response.headers['X-XSS-Protection'] = '1; mode=block' + response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains' + response.headers['Content-Security-Policy'] = "default-src 'self'" + return response # 初始化 Flask 应用 app = Flask(__name__) -app.secret_key = 'this is secret_key you know ?' # 设置 Flask 的密钥,用于 session 加密 +app.secret_key = os.getenv('FLASK_SECRET_KEY', secrets.token_hex(32)) +app.config['SESSION_COOKIE_SECURE'] = True +app.config['SESSION_COOKIE_HTTPONLY'] = True +app.config['SESSION_COOKIE_SAMESITE'] = 'Lax' +app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=2) # 导入蓝图 from views.page import page from views.user import user from views.spider_control import spider_bp -app.register_blueprint(page.pb) # 注册页面蓝图 -app.register_blueprint(user.ub) # 注册用户蓝图 -app.register_blueprint(spider_bp) # 注册爬虫控制蓝图 +app.register_blueprint(page.pb) +app.register_blueprint(user.ub) +app.register_blueprint(spider_bp) -# 首页路由,清空 session +# 首页路由 @app.route('/') def hello_world(): - session.clear() # 清空 session,用户退出登录 - return "Session Cleared" + session.clear() + return redirect('/user/login') -# 中间件:处理请求前的逻辑 +# 请求前中间件 @app.before_request def before_request(): + # 检查是否是HTTPS + if not request.is_secure and not app.debug: + url = request.url.replace('http://', 'https://', 1) + return redirect(url, code=301) + # 如果请求的是静态文件路径,允许访问 if request.path.startswith('/static'): return - + # 如果请求的是登录或注册页面,不需要会话验证 if request.path in ['/user/login', '/user/register']: return - - # 如果 session 中没有用户名,重定向到登录页面 + + # 验证会话 if not session.get('username'): return redirect('/user/login') - -# 404 错误页面路由 -@app.route('/') -def catch_all(path): - return render_template('404.html') # 如果路径不存在,返回 404 页面 - -# 定义定时任务,运行爬虫脚本 -def run_script(): - current_dir = os.path.dirname(os.path.abspath(__file__)) # 获取当前脚本的目录 - spider_script = os.path.join(current_dir, 'spider', 'main.py') # 爬虫脚本路径 - # cutComments_script = os.path.join(current_dir, 'utils', 'cutComments.py') # 评论处理脚本路径 - # cipingTotal_script = os.path.join(current_dir, 'utils', 'cipingTotal.py') # 评分处理脚本路径 - - # 定义所有要运行的脚本 - scripts = [ - ("Spider Script", spider_script), - # ("Cut Comments Script", cutComments_script), - # ("Ciping Total Script", cipingTotal_script) - ] - - # 执行所有脚本 - for script_name, script_path in scripts: - try: - logging.info(f"Running {script_name}...") - subprocess.run(['python', script_path], check=True) # 使用 subprocess 执行脚本 - logging.info(f"{script_name} finished successfully.") - except subprocess.CalledProcessError as e: - logging.error(f"An error occurred while running {script_name}: {e}") - -# 新增功能:动态调度爬虫脚本 -def check_database_empty(): - """ - 检查数据库中的指定表是否为空。 - :return: 如果表为空则返回 True,否则返回 False - """ - try: - connection = pymysql.connect(**DB_CONFIG) - with connection.cursor() as cursor: - cursor.execute("SELECT COUNT(*) as count FROM article") - result = cursor.fetchone() - count = result['count'] if result and 'count' in result else 0 - logging.info(f"数据库中共有 {count} 条记录。") - return count == 0 - except pymysql.MySQLError as e: - logging.error(f"检查数据库失败: {e}") - return True # 连接失败时假设数据库为空,以防止阻塞 - finally: - if 'connection' in locals(): - connection.close() - -def dynamic_crawl(): - """ - 执行爬取任务并根据爬取耗时和获取的数据量动态调度下次爬取时间。 - """ - try: - start_time = time.time() - logging.info("开始爬取数据。") + # 验证会话完整性 + if 'client_info' not in session: + session.clear() + return redirect('/user/login') - run_script() # 执行爬虫脚本 - - end_time = time.time() - duration = end_time - start_time # 爬取耗时 - - # 获取爬取后数据库中记录的数量作为数据量 - try: - connection = pymysql.connect(**DB_CONFIG) - with connection.cursor() as cursor: - cursor.execute("SELECT COUNT(*) as count FROM article") - result = cursor.fetchone() - data_fetched = result['count'] if result and 'count' in result else 0 - logging.info(f"爬取完成,耗时 {duration:.2f} 秒,数据库中共有 {data_fetched} 条记录。") - except pymysql.MySQLError as e: - logging.error(f"获取数据量失败: {e}") - data_fetched = 0 - finally: - if 'connection' in locals(): - connection.close() - - # 根据爬取耗时和数据量调整下次爬取时间 - base_interval = 5 * 60 * 60 # 5小时的基础时间间隔(秒) - - if duration > 3600: # 爬取耗时超过1小时 - next_interval = base_interval + duration - logging.info(f"检测到长时间爬取。下次爬取将在 {next_interval/3600:.2f} 小时后执行。") - elif data_fetched < 50: # 获取的数据量少于50条 - next_interval = base_interval / 2 - logging.info(f"获取数据量较少。下次爬取将在 {next_interval/60:.2f} 分钟后执行。") - else: - next_interval = base_interval - logging.info(f"标准爬取完成。下次爬取将在 {next_interval/3600:.2f} 小时后执行。") - - # 安排下次爬取任务 - scheduler.add_job(dynamic_crawl, 'date', run_date=datetime.now() + timedelta(seconds=next_interval), id='dynamic_crawl') + # 验证客户端信息 + current_client = { + 'ip': request.remote_addr, + 'user_agent': str(request.user_agent) + } + stored_client = session.get('client_info', {}) - except Exception as e: - logging.error(f"动态爬取过程中发生错误: {e}") + if (current_client['ip'] != stored_client.get('ip') or + current_client['user_agent'] != stored_client.get('user_agent')): + session.clear() + return redirect('/user/login') -# 数据库配置,用于动态调度功能 +# 响应后中间件 +@app.after_request +def after_request(response): + return set_secure_headers(response) + +# 错误处理 +@app.errorhandler(404) +def not_found_error(error): + return render_template('404.html'), 404 + +@app.errorhandler(500) +def internal_error(error): + return render_template('500.html'), 500 + +@app.errorhandler(403) +def forbidden_error(error): + return render_template('403.html'), 403 + +@app.errorhandler(400) +def bad_request_error(error): + return render_template('400.html'), 400 + +# 数据库配置 DB_CONFIG = { - 'host': 'localhost', - 'user': 'root', - 'password': '12345678', - 'database': 'Weibo_PublicOpinion_AnalysisSystem', - 'port': 3306, - 'charset': 'utf8mb4' + 'host': os.getenv('DB_HOST', 'localhost'), + 'user': os.getenv('DB_USER', 'root'), + 'password': os.getenv('DB_PASSWORD', ''), + 'database': os.getenv('DB_NAME', 'Weibo_PublicOpinion_AnalysisSystem'), + 'port': int(os.getenv('DB_PORT', '3306')), + 'charset': 'utf8mb4', + 'ssl': {'ca': os.getenv('DB_SSL_CA')} if os.getenv('DB_SSL_CA') else None } # 初始化数据库管理器 DatabaseManager.initialize(DB_CONFIG) -# 主程序入口 if __name__ == '__main__': # 检测是否需要初始化数据库 - if prompt_first_run(): - # 获取数据库连接 - connection = get_db_connection_interactive() - - # 执行数据库初始化 - sql_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'createTables.sql') - initialize_database(connection, sql_file) - - # 关闭数据库连接 - connection.close() - logging.info("数据库连接已关闭。") - - # 设置定时任务,动态执行爬虫脚本 - scheduler = BackgroundScheduler(timezone=utc) # 创建后台任务调度器 - scheduler.start() # 启动调度器 - - # 初始化调度:如果数据库为空,立即爬取;否则,按照基础时间间隔安排首次爬取 - if check_database_empty(): - logging.info("数据库为空。立即开始初始爬取。") - dynamic_crawl() - else: - logging.info("数据库已有数据。安排首次爬取。") - base_interval = 5 * 60 * 60 # 5小时 - scheduler.add_job(dynamic_crawl, 'date', run_date=datetime.now() + timedelta(seconds=base_interval), id='dynamic_crawl') - try: - app.run() # 启动 Flask 应用 - finally: - scheduler.shutdown() # 确保在应用关闭时关闭调度器 + if os.getenv('INITIALIZE_DB', 'false').lower() == 'true': + connection = get_db_connection_interactive() + sql_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'createTables.sql') + initialize_database(connection, sql_file) + connection.close() + logging.info("数据库初始化完成。") + except Exception as e: + logging.error(f"数据库初始化失败: {e}") + exit(1) -# 设置日志记录,捕获应用的请求信息 + # 设置定时任务 + try: + scheduler = BackgroundScheduler(timezone=utc) + scheduler.start() + + if check_database_empty(): + logging.info("数据库为空。立即开始初始爬取。") + dynamic_crawl() + else: + logging.info("数据库已有数据。安排首次爬取。") + base_interval = int(os.getenv('CRAWL_INTERVAL', '18000')) # 默认5小时 + scheduler.add_job( + dynamic_crawl, + 'date', + run_date=datetime.now() + timedelta(seconds=base_interval), + id='dynamic_crawl' + ) + + # 启动应用 + app.run( + host=os.getenv('FLASK_HOST', '127.0.0.1'), + port=int(os.getenv('FLASK_PORT', '5000')), + ssl_context='adhoc' if os.getenv('ENABLE_HTTPS', 'false').lower() == 'true' else None + ) + except Exception as e: + logging.error(f"应用启动失败: {e}") + if 'scheduler' in locals(): + scheduler.shutdown() + exit(1) + finally: + if 'scheduler' in locals(): + scheduler.shutdown() + +# 请求日志记录 @app.before_request def log_request_info(): - # 记录每次请求的信息,便于调试和监控 - logging.info(f"Request: {request.method} {request.path}") # 记录请求的方式(GET/POST)和路径 + # 记录请求信息,但排除敏感数据 + sanitized_headers = dict(request.headers) + if 'Authorization' in sanitized_headers: + sanitized_headers['Authorization'] = '[FILTERED]' + if 'Cookie' in sanitized_headers: + sanitized_headers['Cookie'] = '[FILTERED]' + + logging.info( + f"Request: {request.method} {request.path}\n" + f"Remote IP: {request.remote_addr}\n" + f"Headers: {sanitized_headers}" + ) diff --git a/utils/errorResponse.py b/utils/errorResponse.py index d24badb..80b27b7 100644 --- a/utils/errorResponse.py +++ b/utils/errorResponse.py @@ -1,3 +1,40 @@ -from flask import render_template -def errorResponse(errorMsg): - return render_template('error.html',errorMsg=errorMsg) \ No newline at end of file +from flask import render_template, jsonify +import bleach +import re + +def sanitize_error_message(message): + """ + 清理和验证错误消息 + """ + if not message: + return "发生未知错误" + + # 移除任何敏感信息 + message = re.sub(r'(password|token|key|secret)=[\w\-]+', r'\1=[FILTERED]', str(message)) + + # 清理HTML和特殊字符 + message = bleach.clean(message, strip=True) + + # 限制消息长度 + return message[:200] if len(message) > 200 else message + +def errorResponse(errorMsg, status_code=400): + """ + 统一的错误响应处理 + :param errorMsg: 错误消息 + :param status_code: HTTP状态码 + :return: 错误响应 + """ + safe_message = sanitize_error_message(errorMsg) + + if 'application/json' in request.headers.get('Accept', ''): + return jsonify({ + 'success': False, + 'error': safe_message + }), status_code + + return render_template( + 'error.html', + errorMsg=safe_message, + status_code=status_code + ), status_code \ No newline at end of file diff --git a/views/page/page.py b/views/page/page.py index 26449c0..6f41065 100644 --- a/views/page/page.py +++ b/views/page/page.py @@ -1,4 +1,4 @@ -from flask import Flask, session, render_template, redirect, Blueprint, request, jsonify +from flask import Flask, session, render_template, redirect, Blueprint, request, jsonify, abort from utils.mynlp import SnowNLP from utils.getHomePageData import * from utils.getHotWordPageData import * @@ -16,12 +16,60 @@ from sqlalchemy import create_engine import asyncio import torch from BCAT_front.predict import model_manager +from functools import wraps +import bleach +import re +from datetime import datetime, timedelta pb = Blueprint('page', __name__, url_prefix='/page', template_folder='templates') +def sanitize_input(text): + """清理用户输入,防止XSS攻击""" + if text is None: + return None + return bleach.clean(str(text), strip=True) + +def validate_csrf_token(): + """验证CSRF令牌""" + token = request.form.get('csrf_token') + stored_token = session.get('csrf_token') + if not token or not stored_token or token != stored_token: + return False + return True + +def login_required(f): + @wraps(f) + def decorated_function(*args, **kwargs): + if 'username' not in session: + return redirect('/user/login') + return f(*args, **kwargs) + return decorated_function + +def api_login_required(f): + @wraps(f) + def decorated_function(*args, **kwargs): + if 'username' not in session: + return jsonify({'error': 'Unauthorized'}), 401 + return f(*args, **kwargs) + return decorated_function + +def rate_limit(f): + @wraps(f) + def decorated_function(*args, **kwargs): + key = f"rate_limit:{request.remote_addr}:{f.__name__}" + current = int(redis_client.get(key) or 0) + if current >= 100: # 每分钟100次请求限制 + return jsonify({'error': 'Too many requests'}), 429 + pipe = redis_client.pipeline() + pipe.incr(key) + pipe.expire(key, 60) # 60秒后重置 + pipe.execute() + return f(*args, **kwargs) + return decorated_function + # 设置设备 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -37,14 +85,22 @@ except Exception as e: logging.error(f"模型加载失败: {e}") # 数据库配置 -DATABASE_URL = "sqlite:///ai_analysis.db" +DATABASE_URL = os.getenv('DATABASE_URL', "sqlite:///ai_analysis.db") engine = create_engine(DATABASE_URL) AIAnalysis.metadata.create_all(engine) def predict_sentiment(text): """使用改进版模型预测单个文本的情感""" try: - predictions, probabilities = model_manager.predict_batch([text]) + if not text or len(text.strip()) == 0: + return None, None + + # 清理输入 + cleaned_text = sanitize_input(text) + if not cleaned_text: + return None, None + + predictions, probabilities = model_manager.predict_batch([cleaned_text]) if predictions is not None and len(predictions) > 0: return predictions[0], probabilities[0][predictions[0]] return None, None @@ -53,55 +109,70 @@ def predict_sentiment(text): return None, None @pb.route('/home') +@login_required def home(): - username = session.get('username') - articleLenMax, likeCountMaxAuthorName, cityMax = getHomeTagsData() - commentsLikeCountTopFore = getHomeCommentsLikeCountTopFore() - X, Y = getHomeArticleCreatedAtChart() - typeChart = getHomeTypeChart() - createAtChart = getHomeCommentCreatedChart() - # getUserNameWordCloud() - return render_template('index.html', - username=username, - articleLenMax=articleLenMax, - likeCountMaxAuthorName=likeCountMaxAuthorName, - cityMax=cityMax, - commentsLikeCountTopFore=commentsLikeCountTopFore, - xData=X, - yData=Y, - typeChart=typeChart, - createAtChart=createAtChart) - + try: + username = session.get('username') + articleLenMax, likeCountMaxAuthorName, cityMax = getHomeTagsData() + commentsLikeCountTopFore = getHomeCommentsLikeCountTopFore() + X, Y = getHomeArticleCreatedAtChart() + typeChart = getHomeTypeChart() + createAtChart = getHomeCommentCreatedChart() + + return render_template('index.html', + username=username, + articleLenMax=articleLenMax, + likeCountMaxAuthorName=likeCountMaxAuthorName, + cityMax=cityMax, + commentsLikeCountTopFore=commentsLikeCountTopFore, + xData=X, + yData=Y, + typeChart=typeChart, + createAtChart=createAtChart) + except Exception as e: + logging.error(f"加载首页时发生错误: {e}") + return render_template('error.html', error_message="加载首页失败") @pb.route('/hotWord') +@login_required def hotWord(): - username = session.get('username') - hotWordList = getAllHotWords() - print(hotWordList) - defaultHotWord = hotWordList[0][0] - if request.args.get('hotWord'): - defaultHotWord = request.args.get('hotWord') - hotWordLen = getHotWordLen(defaultHotWord) - X, Y = getHotWordPageCreatedAtCharData(defaultHotWord) - sentences = '' - value = SnowNLP(defaultHotWord).sentiments - if value == 0.5: - sentences = '中性' - elif value > 0.5: - sentences = '正面' - elif value < 0.5: - sentences = '负面' - comments = getCommentFilterData(defaultHotWord) - return render_template('hotWord.html', - username=username, - hotWordList=hotWordList, - defaultHotWord=defaultHotWord, - hotWordLen=hotWordLen, - sentences=sentences, - xData=X, - yData=Y, - comments=comments) - + try: + username = session.get('username') + hotWordList = getAllHotWords() + if not hotWordList: + return render_template('error.html', error_message="无法获取热词列表") + + defaultHotWord = sanitize_input(request.args.get('hotWord', hotWordList[0][0])) + + # 验证热词是否在列表中 + if not any(defaultHotWord in word for word in hotWordList): + return abort(400, "无效的热词") + + hotWordLen = getHotWordLen(defaultHotWord) + X, Y = getHotWordPageCreatedAtCharData(defaultHotWord) + + value = SnowNLP(defaultHotWord).sentiments + if value == 0.5: + sentences = '中性' + elif value > 0.5: + sentences = '正面' + elif value < 0.5: + sentences = '负面' + + comments = getCommentFilterData(defaultHotWord) + + return render_template('hotWord.html', + username=username, + hotWordList=hotWordList, + defaultHotWord=defaultHotWord, + hotWordLen=hotWordLen, + sentences=sentences, + xData=X, + yData=Y, + comments=comments) + except Exception as e: + logging.error(f"加载热词页面时发生错误: {e}") + return render_template('error.html', error_message="加载热词页面失败") @pb.route('/hotTopic') def hotTopic(): @@ -127,18 +198,21 @@ def hotTopic(): yData=Y, comments=comments) - @pb.route('/tableData') +@login_required def tableData(): - username = session.get('username') - defaultFlag = False - if request.args.get('flag'): defaultFlag = True - tableData = getTableDataList(defaultFlag) - return render_template('tableData.html', - username=username, - tableData=tableData, - defaultFlag=defaultFlag) - + try: + username = session.get('username') + defaultFlag = bool(request.args.get('flag', False)) + tableData = getTableDataList(defaultFlag) + + return render_template('tableData.html', + username=username, + tableData=tableData, + defaultFlag=defaultFlag) + except Exception as e: + logging.error(f"加载表格数据时发生错误: {e}") + return render_template('error.html', error_message="加载表格数据失败") @pb.route('/articleChar') def articleChar(): @@ -160,63 +234,89 @@ def articleChar(): x2Data=x2Data, y2Data=y2Data) - @pb.route('/ipChar') +@login_required def ipChar(): - username = session.get('username') - articleRegionData = getIPByArticleRegion() - commentRegionData = getIPByCommentsRegion() - return render_template('ipChar.html', - username=username, - articleRegionData=articleRegionData, - commentRegionData=commentRegionData) - + try: + username = session.get('username') + articleRegionData = getIPByArticleRegion() + commentRegionData = getIPByCommentsRegion() + + return render_template('ipChar.html', + username=username, + articleRegionData=articleRegionData, + commentRegionData=commentRegionData) + except Exception as e: + logging.error(f"加载IP统计时发生错误: {e}") + return render_template('error.html', error_message="加载IP统计失败") @pb.route('/commentChar') +@login_required def commentChar(): - username = session.get('username') - X, Y = getCommentDataOne() - genderPieData = getCommentDataTwo() - return render_template('commentChar.html', - username=username, - xData=X, - yData=Y, - genderPieData=genderPieData) - + try: + username = session.get('username') + X, Y = getCommentDataOne() + genderPieData = getCommentDataTwo() + + return render_template('commentChar.html', + username=username, + xData=X, + yData=Y, + genderPieData=genderPieData) + except Exception as e: + logging.error(f"加载评论统计时发生错误: {e}") + return render_template('error.html', error_message="加载评论统计失败") @pb.route('/yuqingChar') +@login_required def yuqingChar(): - username = session.get('username') - # 获取模型选择参数 - model_type = request.args.get('model', 'pro') # 默认使用改进模型 - - X, Y, biedata = getYuQingCharDataOne() - biedata1, biedata2 = getYuQingCharDataTwo(model_type) - x1Data, y1Data = getYuQingCharDataThree() - return render_template('yuqingChar.html', - username=username, - xData=X, - yData=Y, - biedata=biedata, - biedata1=biedata1, - biedata2=biedata2, - x1Data=x1Data, - y1Data=y1Data, - model_type=model_type) + try: + username = session.get('username') + model_type = sanitize_input(request.args.get('model', 'pro')) + + # 验证模型类型 + if model_type not in ['pro', 'basic']: + return abort(400, "无效的模型类型") + + X, Y, biedata = getYuQingCharDataOne() + biedata1, biedata2 = getYuQingCharDataTwo(model_type) + x1Data, y1Data = getYuQingCharDataThree() + + return render_template('yuqingChar.html', + username=username, + xData=X, + yData=Y, + biedata=biedata, + biedata1=biedata1, + biedata2=biedata2, + x1Data=x1Data, + y1Data=y1Data, + model_type=model_type) + except Exception as e: + logging.error(f"加载舆情统计时发生错误: {e}") + return render_template('error.html', error_message="加载舆情统计失败") @pb.route('/yuqingpredict') +@login_required def yuqingpredict(): try: username = session.get('username') TopicList = getAllTopicData() - defaultTopic = TopicList[0][0] - if request.args.get('Topic'): - defaultTopic = request.args.get('Topic') + if not TopicList: + return render_template('error.html', error_message="无法获取话题列表") + + defaultTopic = sanitize_input(request.args.get('Topic', TopicList[0][0])) + + # 验证话题是否在列表中 + if not any(defaultTopic in topic for topic in TopicList): + return abort(400, "无效的话题") + TopicLen = getTopicLen(defaultTopic) X, Y = getTopicCreatedAtandpredictData(defaultTopic) - # 获取模型选择参数 - model_type = request.args.get('model', 'pro') # 默认使用改进模型 + model_type = sanitize_input(request.args.get('model', 'pro')) + if model_type not in ['pro', 'basic']: + return abort(400, "无效的模型类型") # 尝试从缓存获取预测结果 cache_key = f"{defaultTopic}_{model_type}" @@ -226,7 +326,6 @@ def yuqingpredict(): sentences = cached_result else: if model_type == 'basic': - # 使用基础模型(SnowNLP) value = SnowNLP(defaultTopic).sentiments if value == 0.5: sentences = '中性' @@ -235,7 +334,6 @@ def yuqingpredict(): elif value < 0.5: sentences = '负面' else: - # 使用改进模型 predicted_label, confidence = predict_sentiment(defaultTopic) if predicted_label is not None: sentences = '良好' if predicted_label == 0 else '不良' @@ -248,26 +346,30 @@ def yuqingpredict(): prediction_cache.set(cache_key, sentences) comments = getCommentFilterDataTopic(defaultTopic) + return render_template('yuqingpredict.html', - username=username, - hotWordList=TopicList, - defaultHotWord=defaultTopic, - hotWordLen=TopicLen, - sentences=sentences, - xData=X, - yData=Y, - comments=comments, - model_type=model_type) + username=username, + TopicList=TopicList, + defaultTopic=defaultTopic, + TopicLen=TopicLen, + sentences=sentences, + xData=X, + yData=Y, + comments=comments, + model_type=model_type) except Exception as e: - logging.error(f"舆情预测页面渲染失败: {e}") - return render_template('error.html', error_message="加载舆情预测页面失败,请稍后重试") - + logging.error(f"加载舆情预测时发生错误: {e}") + return render_template('error.html', error_message="加载舆情预测失败") @pb.route('/articleCloud') +@login_required def articleCloud(): - username = session.get('username') - return render_template('articleContentCloud.html', username=username) - + try: + username = session.get('username') + return render_template('articleContentCloud.html', username=username) + except Exception as e: + logging.error(f"加载文章云图时发生错误: {e}") + return render_template('error.html', error_message="加载文章云图失败") @pb.route('/page/index') def index(): @@ -306,15 +408,28 @@ def articleChar(id): return render_template('error.html', error_message="加载文章详情失败") @pb.route('/api/analyze_messages', methods=['POST']) +@api_login_required +@rate_limit async def analyze_messages(): try: - # 获取请求参数 + if not validate_csrf_token(): + return jsonify({'error': 'Invalid CSRF token'}), 403 + data = request.get_json() - batch_size = data.get('batch_size', 50) - model_type = data.get('model_type', 'gpt-3.5-turbo') - analysis_depth = data.get('analysis_depth', 'standard') + if not data: + return jsonify({'error': 'No data provided'}), 400 + + batch_size = min(int(data.get('batch_size', 50)), 100) # 限制批量大小 + model_type = sanitize_input(data.get('model_type', 'gpt-3.5-turbo')) + analysis_depth = sanitize_input(data.get('analysis_depth', 'standard')) + + # 验证参数 + if model_type not in ['gpt-3.5-turbo', 'gpt-4']: + return jsonify({'error': 'Invalid model type'}), 400 + + if analysis_depth not in ['basic', 'standard', 'deep']: + return jsonify({'error': 'Invalid analysis depth'}), 400 - # 获取最近的消息 messages = getRecentMessages(batch_size) if not messages: return jsonify({ @@ -322,7 +437,6 @@ async def analyze_messages(): 'error': '没有找到需要分析的消息' }), 404 - # 调用AI进行分析 analysis_results = await ai_analyzer.analyze_messages( messages=messages, batch_size=batch_size, @@ -336,22 +450,27 @@ async def analyze_messages(): 'error': '分析过程中出现错误' }), 500 - # 保存到数据库 - with Session(engine) as session: - for result in analysis_results: - analysis = AIAnalysis( - message_id=result['message_id'], - sentiment=result['sentiment'], - sentiment_score=float(result['sentiment_score']), - keywords=result['keywords'], - key_points=result['key_points'], - influence_analysis=result['influence_analysis'], - risk_level=result['risk_level'] - ) - session.add(analysis) - session.commit() + try: + with Session(engine) as session: + for result in analysis_results: + analysis = AIAnalysis( + message_id=result['message_id'], + sentiment=result['sentiment'], + sentiment_score=float(result['sentiment_score']), + keywords=result['keywords'], + key_points=result['key_points'], + influence_analysis=result['influence_analysis'], + risk_level=result['risk_level'] + ) + session.add(analysis) + session.commit() + except Exception as e: + logging.error(f"保存分析结果时出错: {e}") + return jsonify({ + 'success': False, + 'error': '保存分析结果失败' + }), 500 - # 格式化结果用于显示 display_results = [ ai_analyzer.format_analysis_for_display(result) for result in analysis_results @@ -359,27 +478,25 @@ async def analyze_messages(): return jsonify({ 'success': True, - 'data': display_results, - 'meta': { - 'total_messages': len(messages), - 'analyzed_messages': len(analysis_results), - 'batch_size': batch_size, - 'model_type': model_type, - 'analysis_depth': analysis_depth - } + 'data': display_results }) - + except Exception as e: - logging.error(f"AI分析过程出错: {e}") + logging.error(f"分析消息时发生错误: {e}") return jsonify({ 'success': False, 'error': str(e) }), 500 @pb.route('/api/get_analysis/') +@api_login_required +@rate_limit def get_message_analysis(message_id): """获取特定消息的分析结果""" try: + if not message_id or message_id < 1: + return jsonify({'error': 'Invalid message ID'}), 400 + with Session(engine) as session: analysis = session.query(AIAnalysis)\ .filter(AIAnalysis.message_id == message_id)\