Optimize code structure and enhance security features.

This commit is contained in:
戒酒的李白
2025-04-02 20:07:16 +08:00
parent bad84f5476
commit 1bbdfcd96b
5 changed files with 166 additions and 77 deletions
+23 -77
View File
@@ -1,19 +1,17 @@
import os import os
import re
import getpass import getpass
import pymysql import pymysql
import subprocess import subprocess
from flask import Flask, session, request, redirect, render_template, jsonify from flask import Flask, session, request, redirect
from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.schedulers.background import BackgroundScheduler
from pytz import utc from pytz import utc
from datetime import datetime, timedelta from datetime import datetime, timedelta
import time
from utils.logger import app_logger as logging
from utils.db_manager import DatabaseManager
import secrets import secrets
from dotenv import load_dotenv from dotenv import load_dotenv
from functools import wraps from utils.logger import app_logger as logging
import bleach from utils.db_pool import DatabasePool
from utils.error_handlers import register_error_handlers
from middleware.security import set_secure_headers, log_request_info, require_https
# 加载环境变量 # 加载环境变量
load_dotenv() load_dotenv()
@@ -56,21 +54,6 @@ def get_db_connection_interactive():
logging.error(f"数据库连接失败: {e}") logging.error(f"数据库连接失败: {e}")
raise raise
def sanitize_input(text):
"""清理用户输入,防止XSS攻击"""
if text is None:
return None
return bleach.clean(str(text), strip=True)
def set_secure_headers(response):
"""设置安全响应头"""
response.headers['X-Content-Type-Options'] = 'nosniff'
response.headers['X-Frame-Options'] = 'SAMEORIGIN'
response.headers['X-XSS-Protection'] = '1; mode=block'
response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains'
response.headers['Content-Security-Policy'] = "default-src 'self'"
return response
# 初始化 Flask 应用 # 初始化 Flask 应用
app = Flask(__name__) app = Flask(__name__)
app.secret_key = os.getenv('FLASK_SECRET_KEY', secrets.token_hex(32)) app.secret_key = os.getenv('FLASK_SECRET_KEY', secrets.token_hex(32))
@@ -87,11 +70,15 @@ from views.workflow_api import workflow_bp, workflow_api_bp
app.register_blueprint(page.pb) app.register_blueprint(page.pb)
app.register_blueprint(user.ub) app.register_blueprint(user.ub)
app.register_blueprint(spider_bp) app.register_blueprint(spider_bp)
app.register_blueprint(workflow_bp) # 注册工作流蓝图 app.register_blueprint(workflow_bp)
app.register_blueprint(workflow_api_bp) # 注册工作流API蓝图 app.register_blueprint(workflow_api_bp)
# 注册错误处理器
register_error_handlers(app)
# 首页路由 # 首页路由
@app.route('/') @app.route('/')
@require_https()
def hello_world(): def hello_world():
session.clear() session.clear()
return redirect('/user/login') return redirect('/user/login')
@@ -99,11 +86,9 @@ def hello_world():
# 请求前中间件 # 请求前中间件
@app.before_request @app.before_request
def before_request(): def before_request():
# 检查是否是HTTPS # 记录请求信息
if not request.is_secure and not app.debug: log_request_info()
url = request.url.replace('http://', 'https://', 1)
return redirect(url, code=301)
# 如果请求的是静态文件路径,允许访问 # 如果请求的是静态文件路径,允许访问
if request.path.startswith('/static'): if request.path.startswith('/static'):
return return
@@ -138,35 +123,6 @@ def before_request():
def after_request(response): def after_request(response):
return set_secure_headers(response) return set_secure_headers(response)
# 错误处理
@app.errorhandler(404)
def not_found_error(error):
return render_template('404.html'), 404
@app.errorhandler(500)
def internal_error(error):
return render_template('error.html',
error_code=500,
error_title='服务器错误',
error_message='服务器遇到了一个问题,请稍后再试。',
error_i18n_key='serverError'), 500
@app.errorhandler(403)
def forbidden_error(error):
return render_template('error.html',
error_code=403,
error_title='禁止访问',
error_message='您没有权限访问此页面。',
error_i18n_key='forbidden'), 403
@app.errorhandler(400)
def bad_request_error(error):
return render_template('error.html',
error_code=400,
error_title='错误请求',
error_message='服务器无法理解您的请求。',
error_i18n_key='badRequest'), 400
# 数据库配置 # 数据库配置
DB_CONFIG = { DB_CONFIG = {
'host': os.getenv('DB_HOST', 'localhost'), 'host': os.getenv('DB_HOST', 'localhost'),
@@ -178,9 +134,6 @@ DB_CONFIG = {
'ssl': {'ca': os.getenv('DB_SSL_CA')} if os.getenv('DB_SSL_CA') else None 'ssl': {'ca': os.getenv('DB_SSL_CA')} if os.getenv('DB_SSL_CA') else None
} }
# 初始化数据库管理器
DatabaseManager.initialize(DB_CONFIG)
if __name__ == '__main__': if __name__ == '__main__':
# 检测是否需要初始化数据库 # 检测是否需要初始化数据库
try: try:
@@ -194,6 +147,13 @@ if __name__ == '__main__':
logging.error(f"数据库初始化失败: {e}") logging.error(f"数据库初始化失败: {e}")
exit(1) exit(1)
# 初始化数据库连接池
try:
DatabasePool.initialize(DB_CONFIG)
except Exception as e:
logging.error(f"数据库连接池初始化失败: {e}")
exit(1)
# 设置定时任务 # 设置定时任务
try: try:
scheduler = BackgroundScheduler(timezone=utc) scheduler = BackgroundScheduler(timezone=utc)
@@ -222,23 +182,9 @@ if __name__ == '__main__':
logging.error(f"应用启动失败: {e}") logging.error(f"应用启动失败: {e}")
if 'scheduler' in locals(): if 'scheduler' in locals():
scheduler.shutdown() scheduler.shutdown()
DatabasePool.close()
exit(1) exit(1)
finally: finally:
if 'scheduler' in locals(): if 'scheduler' in locals():
scheduler.shutdown() scheduler.shutdown()
DatabasePool.close()
# 请求日志记录
@app.before_request
def log_request_info():
# 记录请求信息,但排除敏感数据
sanitized_headers = dict(request.headers)
if 'Authorization' in sanitized_headers:
sanitized_headers['Authorization'] = '[FILTERED]'
if 'Cookie' in sanitized_headers:
sanitized_headers['Cookie'] = '[FILTERED]'
logging.info(
f"Request: {request.method} {request.path}\n"
f"Remote IP: {request.remote_addr}\n"
f"Headers: {sanitized_headers}"
)
+45
View File
@@ -0,0 +1,45 @@
from flask import request, redirect
from functools import wraps
import bleach
from utils.logger import app_logger as logging
def sanitize_input(text):
"""清理用户输入,防止XSS攻击"""
if text is None:
return None
return bleach.clean(str(text), strip=True)
def set_secure_headers(response):
"""设置安全响应头"""
response.headers['X-Content-Type-Options'] = 'nosniff'
response.headers['X-Frame-Options'] = 'SAMEORIGIN'
response.headers['X-XSS-Protection'] = '1; mode=block'
response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains'
response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline';"
return response
def require_https():
"""强制HTTPS中间件"""
def decorator(f):
@wraps(f)
def decorated_function(*args, **kwargs):
if not request.is_secure and not request.is_localhost:
url = request.url.replace('http://', 'https://', 1)
return redirect(url, code=301)
return f(*args, **kwargs)
return decorated_function
return decorator
def log_request_info():
"""请求日志记录中间件"""
sanitized_headers = dict(request.headers)
if 'Authorization' in sanitized_headers:
sanitized_headers['Authorization'] = '[FILTERED]'
if 'Cookie' in sanitized_headers:
sanitized_headers['Cookie'] = '[FILTERED]'
logging.info(
f"Request: {request.method} {request.path}\n"
f"Remote IP: {request.remote_addr}\n"
f"Headers: {sanitized_headers}"
)
+2
View File
@@ -88,3 +88,5 @@ xz=5.4.6=h8cc25b3_1
zipp=3.17.0=py38haa95532_0 zipp=3.17.0=py38haa95532_0
zlib=1.2.13=h8cc25b3_1 zlib=1.2.13=h8cc25b3_1
zstd=1.5.5=hd43e919_2 zstd=1.5.5=hd43e919_2
DBUtils==3.0.2
bleach==6.1.0
+50
View File
@@ -0,0 +1,50 @@
import pymysql
from pymysql.cursors import DictCursor
from dbutils.pooled_db import PooledDB
from utils.logger import app_logger as logging
class DatabasePool:
_pool = None
@classmethod
def initialize(cls, db_config):
"""初始化数据库连接池"""
try:
cls._pool = PooledDB(
creator=pymysql,
maxconnections=10,
mincached=2,
maxcached=5,
maxshared=3,
blocking=True,
maxusage=None,
setsession=[],
ping=0,
host=db_config['host'],
port=db_config['port'],
user=db_config['user'],
password=db_config['password'],
database=db_config['database'],
charset=db_config['charset'],
cursorclass=DictCursor,
ssl=db_config.get('ssl')
)
logging.info("数据库连接池初始化成功")
except Exception as e:
logging.error(f"数据库连接池初始化失败: {e}")
raise
@classmethod
def get_connection(cls):
"""获取数据库连接"""
if cls._pool is None:
raise Exception("数据库连接池未初始化")
return cls._pool.connection()
@classmethod
def close(cls):
"""关闭数据库连接池"""
if cls._pool:
cls._pool._pool.close()
cls._pool = None
logging.info("数据库连接池已关闭")
+46
View File
@@ -0,0 +1,46 @@
from flask import render_template
from utils.logger import app_logger as logging
def register_error_handlers(app):
"""注册错误处理器"""
@app.errorhandler(404)
def not_found_error(error):
logging.warning(f"404错误: {request.url}")
return render_template('404.html'), 404
@app.errorhandler(500)
def internal_error(error):
logging.error(f"500错误: {error}")
return render_template('error.html',
error_code=500,
error_title='服务器错误',
error_message='服务器遇到了一个问题,请稍后再试。',
error_i18n_key='serverError'), 500
@app.errorhandler(403)
def forbidden_error(error):
logging.warning(f"403错误: {request.url}")
return render_template('error.html',
error_code=403,
error_title='禁止访问',
error_message='您没有权限访问此页面。',
error_i18n_key='forbidden'), 403
@app.errorhandler(400)
def bad_request_error(error):
logging.warning(f"400错误: {error}")
return render_template('error.html',
error_code=400,
error_title='错误请求',
error_message='服务器无法理解您的请求。',
error_i18n_key='badRequest'), 400
@app.errorhandler(Exception)
def handle_exception(error):
logging.error(f"未处理的异常: {error}")
return render_template('error.html',
error_code=500,
error_title='系统错误',
error_message='系统发生了一个未预期的错误。',
error_i18n_key='unexpectedError'), 500