Optimize code structure and enhance security features.
This commit is contained in:
@@ -1,19 +1,17 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import getpass
|
import getpass
|
||||||
import pymysql
|
import pymysql
|
||||||
import subprocess
|
import subprocess
|
||||||
from flask import Flask, session, request, redirect, render_template, jsonify
|
from flask import Flask, session, request, redirect
|
||||||
from apscheduler.schedulers.background import BackgroundScheduler
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
from pytz import utc
|
from pytz import utc
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import time
|
|
||||||
from utils.logger import app_logger as logging
|
|
||||||
from utils.db_manager import DatabaseManager
|
|
||||||
import secrets
|
import secrets
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from functools import wraps
|
from utils.logger import app_logger as logging
|
||||||
import bleach
|
from utils.db_pool import DatabasePool
|
||||||
|
from utils.error_handlers import register_error_handlers
|
||||||
|
from middleware.security import set_secure_headers, log_request_info, require_https
|
||||||
|
|
||||||
# 加载环境变量
|
# 加载环境变量
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
@@ -56,21 +54,6 @@ def get_db_connection_interactive():
|
|||||||
logging.error(f"数据库连接失败: {e}")
|
logging.error(f"数据库连接失败: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def sanitize_input(text):
|
|
||||||
"""清理用户输入,防止XSS攻击"""
|
|
||||||
if text is None:
|
|
||||||
return None
|
|
||||||
return bleach.clean(str(text), strip=True)
|
|
||||||
|
|
||||||
def set_secure_headers(response):
|
|
||||||
"""设置安全响应头"""
|
|
||||||
response.headers['X-Content-Type-Options'] = 'nosniff'
|
|
||||||
response.headers['X-Frame-Options'] = 'SAMEORIGIN'
|
|
||||||
response.headers['X-XSS-Protection'] = '1; mode=block'
|
|
||||||
response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains'
|
|
||||||
response.headers['Content-Security-Policy'] = "default-src 'self'"
|
|
||||||
return response
|
|
||||||
|
|
||||||
# 初始化 Flask 应用
|
# 初始化 Flask 应用
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
app.secret_key = os.getenv('FLASK_SECRET_KEY', secrets.token_hex(32))
|
app.secret_key = os.getenv('FLASK_SECRET_KEY', secrets.token_hex(32))
|
||||||
@@ -87,11 +70,15 @@ from views.workflow_api import workflow_bp, workflow_api_bp
|
|||||||
app.register_blueprint(page.pb)
|
app.register_blueprint(page.pb)
|
||||||
app.register_blueprint(user.ub)
|
app.register_blueprint(user.ub)
|
||||||
app.register_blueprint(spider_bp)
|
app.register_blueprint(spider_bp)
|
||||||
app.register_blueprint(workflow_bp) # 注册工作流蓝图
|
app.register_blueprint(workflow_bp)
|
||||||
app.register_blueprint(workflow_api_bp) # 注册工作流API蓝图
|
app.register_blueprint(workflow_api_bp)
|
||||||
|
|
||||||
|
# 注册错误处理器
|
||||||
|
register_error_handlers(app)
|
||||||
|
|
||||||
# 首页路由
|
# 首页路由
|
||||||
@app.route('/')
|
@app.route('/')
|
||||||
|
@require_https()
|
||||||
def hello_world():
|
def hello_world():
|
||||||
session.clear()
|
session.clear()
|
||||||
return redirect('/user/login')
|
return redirect('/user/login')
|
||||||
@@ -99,11 +86,9 @@ def hello_world():
|
|||||||
# 请求前中间件
|
# 请求前中间件
|
||||||
@app.before_request
|
@app.before_request
|
||||||
def before_request():
|
def before_request():
|
||||||
# 检查是否是HTTPS
|
# 记录请求信息
|
||||||
if not request.is_secure and not app.debug:
|
log_request_info()
|
||||||
url = request.url.replace('http://', 'https://', 1)
|
|
||||||
return redirect(url, code=301)
|
|
||||||
|
|
||||||
# 如果请求的是静态文件路径,允许访问
|
# 如果请求的是静态文件路径,允许访问
|
||||||
if request.path.startswith('/static'):
|
if request.path.startswith('/static'):
|
||||||
return
|
return
|
||||||
@@ -138,35 +123,6 @@ def before_request():
|
|||||||
def after_request(response):
|
def after_request(response):
|
||||||
return set_secure_headers(response)
|
return set_secure_headers(response)
|
||||||
|
|
||||||
# 错误处理
|
|
||||||
@app.errorhandler(404)
|
|
||||||
def not_found_error(error):
|
|
||||||
return render_template('404.html'), 404
|
|
||||||
|
|
||||||
@app.errorhandler(500)
|
|
||||||
def internal_error(error):
|
|
||||||
return render_template('error.html',
|
|
||||||
error_code=500,
|
|
||||||
error_title='服务器错误',
|
|
||||||
error_message='服务器遇到了一个问题,请稍后再试。',
|
|
||||||
error_i18n_key='serverError'), 500
|
|
||||||
|
|
||||||
@app.errorhandler(403)
|
|
||||||
def forbidden_error(error):
|
|
||||||
return render_template('error.html',
|
|
||||||
error_code=403,
|
|
||||||
error_title='禁止访问',
|
|
||||||
error_message='您没有权限访问此页面。',
|
|
||||||
error_i18n_key='forbidden'), 403
|
|
||||||
|
|
||||||
@app.errorhandler(400)
|
|
||||||
def bad_request_error(error):
|
|
||||||
return render_template('error.html',
|
|
||||||
error_code=400,
|
|
||||||
error_title='错误请求',
|
|
||||||
error_message='服务器无法理解您的请求。',
|
|
||||||
error_i18n_key='badRequest'), 400
|
|
||||||
|
|
||||||
# 数据库配置
|
# 数据库配置
|
||||||
DB_CONFIG = {
|
DB_CONFIG = {
|
||||||
'host': os.getenv('DB_HOST', 'localhost'),
|
'host': os.getenv('DB_HOST', 'localhost'),
|
||||||
@@ -178,9 +134,6 @@ DB_CONFIG = {
|
|||||||
'ssl': {'ca': os.getenv('DB_SSL_CA')} if os.getenv('DB_SSL_CA') else None
|
'ssl': {'ca': os.getenv('DB_SSL_CA')} if os.getenv('DB_SSL_CA') else None
|
||||||
}
|
}
|
||||||
|
|
||||||
# 初始化数据库管理器
|
|
||||||
DatabaseManager.initialize(DB_CONFIG)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# 检测是否需要初始化数据库
|
# 检测是否需要初始化数据库
|
||||||
try:
|
try:
|
||||||
@@ -194,6 +147,13 @@ if __name__ == '__main__':
|
|||||||
logging.error(f"数据库初始化失败: {e}")
|
logging.error(f"数据库初始化失败: {e}")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
|
# 初始化数据库连接池
|
||||||
|
try:
|
||||||
|
DatabasePool.initialize(DB_CONFIG)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"数据库连接池初始化失败: {e}")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
# 设置定时任务
|
# 设置定时任务
|
||||||
try:
|
try:
|
||||||
scheduler = BackgroundScheduler(timezone=utc)
|
scheduler = BackgroundScheduler(timezone=utc)
|
||||||
@@ -222,23 +182,9 @@ if __name__ == '__main__':
|
|||||||
logging.error(f"应用启动失败: {e}")
|
logging.error(f"应用启动失败: {e}")
|
||||||
if 'scheduler' in locals():
|
if 'scheduler' in locals():
|
||||||
scheduler.shutdown()
|
scheduler.shutdown()
|
||||||
|
DatabasePool.close()
|
||||||
exit(1)
|
exit(1)
|
||||||
finally:
|
finally:
|
||||||
if 'scheduler' in locals():
|
if 'scheduler' in locals():
|
||||||
scheduler.shutdown()
|
scheduler.shutdown()
|
||||||
|
DatabasePool.close()
|
||||||
# 请求日志记录
|
|
||||||
@app.before_request
|
|
||||||
def log_request_info():
|
|
||||||
# 记录请求信息,但排除敏感数据
|
|
||||||
sanitized_headers = dict(request.headers)
|
|
||||||
if 'Authorization' in sanitized_headers:
|
|
||||||
sanitized_headers['Authorization'] = '[FILTERED]'
|
|
||||||
if 'Cookie' in sanitized_headers:
|
|
||||||
sanitized_headers['Cookie'] = '[FILTERED]'
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
f"Request: {request.method} {request.path}\n"
|
|
||||||
f"Remote IP: {request.remote_addr}\n"
|
|
||||||
f"Headers: {sanitized_headers}"
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -0,0 +1,45 @@
|
|||||||
|
from flask import request, redirect
|
||||||
|
from functools import wraps
|
||||||
|
import bleach
|
||||||
|
from utils.logger import app_logger as logging
|
||||||
|
|
||||||
|
def sanitize_input(text):
|
||||||
|
"""清理用户输入,防止XSS攻击"""
|
||||||
|
if text is None:
|
||||||
|
return None
|
||||||
|
return bleach.clean(str(text), strip=True)
|
||||||
|
|
||||||
|
def set_secure_headers(response):
|
||||||
|
"""设置安全响应头"""
|
||||||
|
response.headers['X-Content-Type-Options'] = 'nosniff'
|
||||||
|
response.headers['X-Frame-Options'] = 'SAMEORIGIN'
|
||||||
|
response.headers['X-XSS-Protection'] = '1; mode=block'
|
||||||
|
response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains'
|
||||||
|
response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline';"
|
||||||
|
return response
|
||||||
|
|
||||||
|
def require_https():
|
||||||
|
"""强制HTTPS中间件"""
|
||||||
|
def decorator(f):
|
||||||
|
@wraps(f)
|
||||||
|
def decorated_function(*args, **kwargs):
|
||||||
|
if not request.is_secure and not request.is_localhost:
|
||||||
|
url = request.url.replace('http://', 'https://', 1)
|
||||||
|
return redirect(url, code=301)
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
return decorated_function
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def log_request_info():
|
||||||
|
"""请求日志记录中间件"""
|
||||||
|
sanitized_headers = dict(request.headers)
|
||||||
|
if 'Authorization' in sanitized_headers:
|
||||||
|
sanitized_headers['Authorization'] = '[FILTERED]'
|
||||||
|
if 'Cookie' in sanitized_headers:
|
||||||
|
sanitized_headers['Cookie'] = '[FILTERED]'
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
f"Request: {request.method} {request.path}\n"
|
||||||
|
f"Remote IP: {request.remote_addr}\n"
|
||||||
|
f"Headers: {sanitized_headers}"
|
||||||
|
)
|
||||||
@@ -88,3 +88,5 @@ xz=5.4.6=h8cc25b3_1
|
|||||||
zipp=3.17.0=py38haa95532_0
|
zipp=3.17.0=py38haa95532_0
|
||||||
zlib=1.2.13=h8cc25b3_1
|
zlib=1.2.13=h8cc25b3_1
|
||||||
zstd=1.5.5=hd43e919_2
|
zstd=1.5.5=hd43e919_2
|
||||||
|
DBUtils==3.0.2
|
||||||
|
bleach==6.1.0
|
||||||
|
|||||||
@@ -0,0 +1,50 @@
|
|||||||
|
import pymysql
|
||||||
|
from pymysql.cursors import DictCursor
|
||||||
|
from dbutils.pooled_db import PooledDB
|
||||||
|
from utils.logger import app_logger as logging
|
||||||
|
|
||||||
|
class DatabasePool:
|
||||||
|
_pool = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def initialize(cls, db_config):
|
||||||
|
"""初始化数据库连接池"""
|
||||||
|
try:
|
||||||
|
cls._pool = PooledDB(
|
||||||
|
creator=pymysql,
|
||||||
|
maxconnections=10,
|
||||||
|
mincached=2,
|
||||||
|
maxcached=5,
|
||||||
|
maxshared=3,
|
||||||
|
blocking=True,
|
||||||
|
maxusage=None,
|
||||||
|
setsession=[],
|
||||||
|
ping=0,
|
||||||
|
host=db_config['host'],
|
||||||
|
port=db_config['port'],
|
||||||
|
user=db_config['user'],
|
||||||
|
password=db_config['password'],
|
||||||
|
database=db_config['database'],
|
||||||
|
charset=db_config['charset'],
|
||||||
|
cursorclass=DictCursor,
|
||||||
|
ssl=db_config.get('ssl')
|
||||||
|
)
|
||||||
|
logging.info("数据库连接池初始化成功")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"数据库连接池初始化失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_connection(cls):
|
||||||
|
"""获取数据库连接"""
|
||||||
|
if cls._pool is None:
|
||||||
|
raise Exception("数据库连接池未初始化")
|
||||||
|
return cls._pool.connection()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def close(cls):
|
||||||
|
"""关闭数据库连接池"""
|
||||||
|
if cls._pool:
|
||||||
|
cls._pool._pool.close()
|
||||||
|
cls._pool = None
|
||||||
|
logging.info("数据库连接池已关闭")
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
from flask import render_template
|
||||||
|
from utils.logger import app_logger as logging
|
||||||
|
|
||||||
|
def register_error_handlers(app):
|
||||||
|
"""注册错误处理器"""
|
||||||
|
|
||||||
|
@app.errorhandler(404)
|
||||||
|
def not_found_error(error):
|
||||||
|
logging.warning(f"404错误: {request.url}")
|
||||||
|
return render_template('404.html'), 404
|
||||||
|
|
||||||
|
@app.errorhandler(500)
|
||||||
|
def internal_error(error):
|
||||||
|
logging.error(f"500错误: {error}")
|
||||||
|
return render_template('error.html',
|
||||||
|
error_code=500,
|
||||||
|
error_title='服务器错误',
|
||||||
|
error_message='服务器遇到了一个问题,请稍后再试。',
|
||||||
|
error_i18n_key='serverError'), 500
|
||||||
|
|
||||||
|
@app.errorhandler(403)
|
||||||
|
def forbidden_error(error):
|
||||||
|
logging.warning(f"403错误: {request.url}")
|
||||||
|
return render_template('error.html',
|
||||||
|
error_code=403,
|
||||||
|
error_title='禁止访问',
|
||||||
|
error_message='您没有权限访问此页面。',
|
||||||
|
error_i18n_key='forbidden'), 403
|
||||||
|
|
||||||
|
@app.errorhandler(400)
|
||||||
|
def bad_request_error(error):
|
||||||
|
logging.warning(f"400错误: {error}")
|
||||||
|
return render_template('error.html',
|
||||||
|
error_code=400,
|
||||||
|
error_title='错误请求',
|
||||||
|
error_message='服务器无法理解您的请求。',
|
||||||
|
error_i18n_key='badRequest'), 400
|
||||||
|
|
||||||
|
@app.errorhandler(Exception)
|
||||||
|
def handle_exception(error):
|
||||||
|
logging.error(f"未处理的异常: {error}")
|
||||||
|
return render_template('error.html',
|
||||||
|
error_code=500,
|
||||||
|
error_title='系统错误',
|
||||||
|
error_message='系统发生了一个未预期的错误。',
|
||||||
|
error_i18n_key='unexpectedError'), 500
|
||||||
Reference in New Issue
Block a user