From 1bbdfcd96b0c190ade09deb46a8dde7b82ae0947 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=88=92=E9=85=92=E7=9A=84=E6=9D=8E=E7=99=BD?= <670939375@qq.com> Date: Wed, 2 Apr 2025 20:07:16 +0800 Subject: [PATCH] Optimize code structure and enhance security features. --- app.py | 100 +++++++++------------------------------- middleware/security.py | 45 ++++++++++++++++++ requirements.txt | 2 + utils/db_pool.py | 50 ++++++++++++++++++++ utils/error_handlers.py | 46 ++++++++++++++++++ 5 files changed, 166 insertions(+), 77 deletions(-) create mode 100644 middleware/security.py create mode 100644 utils/db_pool.py create mode 100644 utils/error_handlers.py diff --git a/app.py b/app.py index 391230b..fe1ee4d 100644 --- a/app.py +++ b/app.py @@ -1,19 +1,17 @@ import os -import re import getpass import pymysql import subprocess -from flask import Flask, session, request, redirect, render_template, jsonify +from flask import Flask, session, request, redirect from apscheduler.schedulers.background import BackgroundScheduler from pytz import utc from datetime import datetime, timedelta -import time -from utils.logger import app_logger as logging -from utils.db_manager import DatabaseManager import secrets from dotenv import load_dotenv -from functools import wraps -import bleach +from utils.logger import app_logger as logging +from utils.db_pool import DatabasePool +from utils.error_handlers import register_error_handlers +from middleware.security import set_secure_headers, log_request_info, require_https # 加载环境变量 load_dotenv() @@ -56,21 +54,6 @@ def get_db_connection_interactive(): logging.error(f"数据库连接失败: {e}") raise -def sanitize_input(text): - """清理用户输入,防止XSS攻击""" - if text is None: - return None - return bleach.clean(str(text), strip=True) - -def set_secure_headers(response): - """设置安全响应头""" - response.headers['X-Content-Type-Options'] = 'nosniff' - response.headers['X-Frame-Options'] = 'SAMEORIGIN' - response.headers['X-XSS-Protection'] = '1; mode=block' - response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains' - response.headers['Content-Security-Policy'] = "default-src 'self'" - return response - # 初始化 Flask 应用 app = Flask(__name__) app.secret_key = os.getenv('FLASK_SECRET_KEY', secrets.token_hex(32)) @@ -87,11 +70,15 @@ from views.workflow_api import workflow_bp, workflow_api_bp app.register_blueprint(page.pb) app.register_blueprint(user.ub) app.register_blueprint(spider_bp) -app.register_blueprint(workflow_bp) # 注册工作流蓝图 -app.register_blueprint(workflow_api_bp) # 注册工作流API蓝图 +app.register_blueprint(workflow_bp) +app.register_blueprint(workflow_api_bp) + +# 注册错误处理器 +register_error_handlers(app) # 首页路由 @app.route('/') +@require_https() def hello_world(): session.clear() return redirect('/user/login') @@ -99,11 +86,9 @@ def hello_world(): # 请求前中间件 @app.before_request def before_request(): - # 检查是否是HTTPS - if not request.is_secure and not app.debug: - url = request.url.replace('http://', 'https://', 1) - return redirect(url, code=301) - + # 记录请求信息 + log_request_info() + # 如果请求的是静态文件路径,允许访问 if request.path.startswith('/static'): return @@ -138,35 +123,6 @@ def before_request(): def after_request(response): return set_secure_headers(response) -# 错误处理 -@app.errorhandler(404) -def not_found_error(error): - return render_template('404.html'), 404 - -@app.errorhandler(500) -def internal_error(error): - return render_template('error.html', - error_code=500, - error_title='服务器错误', - error_message='服务器遇到了一个问题,请稍后再试。', - error_i18n_key='serverError'), 500 - -@app.errorhandler(403) -def forbidden_error(error): - return render_template('error.html', - error_code=403, - error_title='禁止访问', - error_message='您没有权限访问此页面。', - error_i18n_key='forbidden'), 403 - -@app.errorhandler(400) -def bad_request_error(error): - return render_template('error.html', - error_code=400, - error_title='错误请求', - error_message='服务器无法理解您的请求。', - error_i18n_key='badRequest'), 400 - # 数据库配置 DB_CONFIG = { 'host': os.getenv('DB_HOST', 'localhost'), @@ -178,9 +134,6 @@ DB_CONFIG = { 'ssl': {'ca': os.getenv('DB_SSL_CA')} if os.getenv('DB_SSL_CA') else None } -# 初始化数据库管理器 -DatabaseManager.initialize(DB_CONFIG) - if __name__ == '__main__': # 检测是否需要初始化数据库 try: @@ -194,6 +147,13 @@ if __name__ == '__main__': logging.error(f"数据库初始化失败: {e}") exit(1) + # 初始化数据库连接池 + try: + DatabasePool.initialize(DB_CONFIG) + except Exception as e: + logging.error(f"数据库连接池初始化失败: {e}") + exit(1) + # 设置定时任务 try: scheduler = BackgroundScheduler(timezone=utc) @@ -222,23 +182,9 @@ if __name__ == '__main__': logging.error(f"应用启动失败: {e}") if 'scheduler' in locals(): scheduler.shutdown() + DatabasePool.close() exit(1) finally: if 'scheduler' in locals(): scheduler.shutdown() - -# 请求日志记录 -@app.before_request -def log_request_info(): - # 记录请求信息,但排除敏感数据 - sanitized_headers = dict(request.headers) - if 'Authorization' in sanitized_headers: - sanitized_headers['Authorization'] = '[FILTERED]' - if 'Cookie' in sanitized_headers: - sanitized_headers['Cookie'] = '[FILTERED]' - - logging.info( - f"Request: {request.method} {request.path}\n" - f"Remote IP: {request.remote_addr}\n" - f"Headers: {sanitized_headers}" - ) + DatabasePool.close() diff --git a/middleware/security.py b/middleware/security.py new file mode 100644 index 0000000..1fdcfb6 --- /dev/null +++ b/middleware/security.py @@ -0,0 +1,45 @@ +from flask import request, redirect +from functools import wraps +import bleach +from utils.logger import app_logger as logging + +def sanitize_input(text): + """清理用户输入,防止XSS攻击""" + if text is None: + return None + return bleach.clean(str(text), strip=True) + +def set_secure_headers(response): + """设置安全响应头""" + response.headers['X-Content-Type-Options'] = 'nosniff' + response.headers['X-Frame-Options'] = 'SAMEORIGIN' + response.headers['X-XSS-Protection'] = '1; mode=block' + response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains' + response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline';" + return response + +def require_https(): + """强制HTTPS中间件""" + def decorator(f): + @wraps(f) + def decorated_function(*args, **kwargs): + if not request.is_secure and not request.is_localhost: + url = request.url.replace('http://', 'https://', 1) + return redirect(url, code=301) + return f(*args, **kwargs) + return decorated_function + return decorator + +def log_request_info(): + """请求日志记录中间件""" + sanitized_headers = dict(request.headers) + if 'Authorization' in sanitized_headers: + sanitized_headers['Authorization'] = '[FILTERED]' + if 'Cookie' in sanitized_headers: + sanitized_headers['Cookie'] = '[FILTERED]' + + logging.info( + f"Request: {request.method} {request.path}\n" + f"Remote IP: {request.remote_addr}\n" + f"Headers: {sanitized_headers}" + ) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index c06e17d..bc1be3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -88,3 +88,5 @@ xz=5.4.6=h8cc25b3_1 zipp=3.17.0=py38haa95532_0 zlib=1.2.13=h8cc25b3_1 zstd=1.5.5=hd43e919_2 +DBUtils==3.0.2 +bleach==6.1.0 diff --git a/utils/db_pool.py b/utils/db_pool.py new file mode 100644 index 0000000..62d98b6 --- /dev/null +++ b/utils/db_pool.py @@ -0,0 +1,50 @@ +import pymysql +from pymysql.cursors import DictCursor +from dbutils.pooled_db import PooledDB +from utils.logger import app_logger as logging + +class DatabasePool: + _pool = None + + @classmethod + def initialize(cls, db_config): + """初始化数据库连接池""" + try: + cls._pool = PooledDB( + creator=pymysql, + maxconnections=10, + mincached=2, + maxcached=5, + maxshared=3, + blocking=True, + maxusage=None, + setsession=[], + ping=0, + host=db_config['host'], + port=db_config['port'], + user=db_config['user'], + password=db_config['password'], + database=db_config['database'], + charset=db_config['charset'], + cursorclass=DictCursor, + ssl=db_config.get('ssl') + ) + logging.info("数据库连接池初始化成功") + except Exception as e: + logging.error(f"数据库连接池初始化失败: {e}") + raise + + @classmethod + def get_connection(cls): + """获取数据库连接""" + if cls._pool is None: + raise Exception("数据库连接池未初始化") + return cls._pool.connection() + + @classmethod + def close(cls): + """关闭数据库连接池""" + if cls._pool: + cls._pool._pool.close() + cls._pool = None + logging.info("数据库连接池已关闭") \ No newline at end of file diff --git a/utils/error_handlers.py b/utils/error_handlers.py new file mode 100644 index 0000000..a500a18 --- /dev/null +++ b/utils/error_handlers.py @@ -0,0 +1,46 @@ +from flask import render_template +from utils.logger import app_logger as logging + +def register_error_handlers(app): + """注册错误处理器""" + + @app.errorhandler(404) + def not_found_error(error): + logging.warning(f"404错误: {request.url}") + return render_template('404.html'), 404 + + @app.errorhandler(500) + def internal_error(error): + logging.error(f"500错误: {error}") + return render_template('error.html', + error_code=500, + error_title='服务器错误', + error_message='服务器遇到了一个问题,请稍后再试。', + error_i18n_key='serverError'), 500 + + @app.errorhandler(403) + def forbidden_error(error): + logging.warning(f"403错误: {request.url}") + return render_template('error.html', + error_code=403, + error_title='禁止访问', + error_message='您没有权限访问此页面。', + error_i18n_key='forbidden'), 403 + + @app.errorhandler(400) + def bad_request_error(error): + logging.warning(f"400错误: {error}") + return render_template('error.html', + error_code=400, + error_title='错误请求', + error_message='服务器无法理解您的请求。', + error_i18n_key='badRequest'), 400 + + @app.errorhandler(Exception) + def handle_exception(error): + logging.error(f"未处理的异常: {error}") + return render_template('error.html', + error_code=500, + error_title='系统错误', + error_message='系统发生了一个未预期的错误。', + error_i18n_key='unexpectedError'), 500 \ No newline at end of file