Optimize code structure and enhance security features.
This commit is contained in:
@@ -1,19 +1,17 @@
|
||||
import os
|
||||
import re
|
||||
import getpass
|
||||
import pymysql
|
||||
import subprocess
|
||||
from flask import Flask, session, request, redirect, render_template, jsonify
|
||||
from flask import Flask, session, request, redirect
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from pytz import utc
|
||||
from datetime import datetime, timedelta
|
||||
import time
|
||||
from utils.logger import app_logger as logging
|
||||
from utils.db_manager import DatabaseManager
|
||||
import secrets
|
||||
from dotenv import load_dotenv
|
||||
from functools import wraps
|
||||
import bleach
|
||||
from utils.logger import app_logger as logging
|
||||
from utils.db_pool import DatabasePool
|
||||
from utils.error_handlers import register_error_handlers
|
||||
from middleware.security import set_secure_headers, log_request_info, require_https
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
@@ -56,21 +54,6 @@ def get_db_connection_interactive():
|
||||
logging.error(f"数据库连接失败: {e}")
|
||||
raise
|
||||
|
||||
def sanitize_input(text):
|
||||
"""清理用户输入,防止XSS攻击"""
|
||||
if text is None:
|
||||
return None
|
||||
return bleach.clean(str(text), strip=True)
|
||||
|
||||
def set_secure_headers(response):
|
||||
"""设置安全响应头"""
|
||||
response.headers['X-Content-Type-Options'] = 'nosniff'
|
||||
response.headers['X-Frame-Options'] = 'SAMEORIGIN'
|
||||
response.headers['X-XSS-Protection'] = '1; mode=block'
|
||||
response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains'
|
||||
response.headers['Content-Security-Policy'] = "default-src 'self'"
|
||||
return response
|
||||
|
||||
# 初始化 Flask 应用
|
||||
app = Flask(__name__)
|
||||
app.secret_key = os.getenv('FLASK_SECRET_KEY', secrets.token_hex(32))
|
||||
@@ -87,11 +70,15 @@ from views.workflow_api import workflow_bp, workflow_api_bp
|
||||
app.register_blueprint(page.pb)
|
||||
app.register_blueprint(user.ub)
|
||||
app.register_blueprint(spider_bp)
|
||||
app.register_blueprint(workflow_bp) # 注册工作流蓝图
|
||||
app.register_blueprint(workflow_api_bp) # 注册工作流API蓝图
|
||||
app.register_blueprint(workflow_bp)
|
||||
app.register_blueprint(workflow_api_bp)
|
||||
|
||||
# 注册错误处理器
|
||||
register_error_handlers(app)
|
||||
|
||||
# 首页路由
|
||||
@app.route('/')
|
||||
@require_https()
|
||||
def hello_world():
|
||||
session.clear()
|
||||
return redirect('/user/login')
|
||||
@@ -99,11 +86,9 @@ def hello_world():
|
||||
# 请求前中间件
|
||||
@app.before_request
|
||||
def before_request():
|
||||
# 检查是否是HTTPS
|
||||
if not request.is_secure and not app.debug:
|
||||
url = request.url.replace('http://', 'https://', 1)
|
||||
return redirect(url, code=301)
|
||||
|
||||
# 记录请求信息
|
||||
log_request_info()
|
||||
|
||||
# 如果请求的是静态文件路径,允许访问
|
||||
if request.path.startswith('/static'):
|
||||
return
|
||||
@@ -138,35 +123,6 @@ def before_request():
|
||||
def after_request(response):
|
||||
return set_secure_headers(response)
|
||||
|
||||
# 错误处理
|
||||
@app.errorhandler(404)
|
||||
def not_found_error(error):
|
||||
return render_template('404.html'), 404
|
||||
|
||||
@app.errorhandler(500)
|
||||
def internal_error(error):
|
||||
return render_template('error.html',
|
||||
error_code=500,
|
||||
error_title='服务器错误',
|
||||
error_message='服务器遇到了一个问题,请稍后再试。',
|
||||
error_i18n_key='serverError'), 500
|
||||
|
||||
@app.errorhandler(403)
|
||||
def forbidden_error(error):
|
||||
return render_template('error.html',
|
||||
error_code=403,
|
||||
error_title='禁止访问',
|
||||
error_message='您没有权限访问此页面。',
|
||||
error_i18n_key='forbidden'), 403
|
||||
|
||||
@app.errorhandler(400)
|
||||
def bad_request_error(error):
|
||||
return render_template('error.html',
|
||||
error_code=400,
|
||||
error_title='错误请求',
|
||||
error_message='服务器无法理解您的请求。',
|
||||
error_i18n_key='badRequest'), 400
|
||||
|
||||
# 数据库配置
|
||||
DB_CONFIG = {
|
||||
'host': os.getenv('DB_HOST', 'localhost'),
|
||||
@@ -178,9 +134,6 @@ DB_CONFIG = {
|
||||
'ssl': {'ca': os.getenv('DB_SSL_CA')} if os.getenv('DB_SSL_CA') else None
|
||||
}
|
||||
|
||||
# 初始化数据库管理器
|
||||
DatabaseManager.initialize(DB_CONFIG)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 检测是否需要初始化数据库
|
||||
try:
|
||||
@@ -194,6 +147,13 @@ if __name__ == '__main__':
|
||||
logging.error(f"数据库初始化失败: {e}")
|
||||
exit(1)
|
||||
|
||||
# 初始化数据库连接池
|
||||
try:
|
||||
DatabasePool.initialize(DB_CONFIG)
|
||||
except Exception as e:
|
||||
logging.error(f"数据库连接池初始化失败: {e}")
|
||||
exit(1)
|
||||
|
||||
# 设置定时任务
|
||||
try:
|
||||
scheduler = BackgroundScheduler(timezone=utc)
|
||||
@@ -222,23 +182,9 @@ if __name__ == '__main__':
|
||||
logging.error(f"应用启动失败: {e}")
|
||||
if 'scheduler' in locals():
|
||||
scheduler.shutdown()
|
||||
DatabasePool.close()
|
||||
exit(1)
|
||||
finally:
|
||||
if 'scheduler' in locals():
|
||||
scheduler.shutdown()
|
||||
|
||||
# 请求日志记录
|
||||
@app.before_request
|
||||
def log_request_info():
|
||||
# 记录请求信息,但排除敏感数据
|
||||
sanitized_headers = dict(request.headers)
|
||||
if 'Authorization' in sanitized_headers:
|
||||
sanitized_headers['Authorization'] = '[FILTERED]'
|
||||
if 'Cookie' in sanitized_headers:
|
||||
sanitized_headers['Cookie'] = '[FILTERED]'
|
||||
|
||||
logging.info(
|
||||
f"Request: {request.method} {request.path}\n"
|
||||
f"Remote IP: {request.remote_addr}\n"
|
||||
f"Headers: {sanitized_headers}"
|
||||
)
|
||||
DatabasePool.close()
|
||||
|
||||
Reference in New Issue
Block a user