194 lines
6.3 KiB
Python
194 lines
6.3 KiB
Python
import os
|
|
import getpass
|
|
import pymysql
|
|
import subprocess
|
|
from flask import Flask, session, request, redirect
|
|
from apscheduler.schedulers.background import BackgroundScheduler
|
|
try:
|
|
from zoneinfo import ZoneInfo # Python 3.9+
|
|
except ImportError:
|
|
from backports.zoneinfo import ZoneInfo # Python < 3.9
|
|
from datetime import datetime, timedelta
|
|
import secrets
|
|
from dotenv import load_dotenv
|
|
from utils.logger import app_logger as logging
|
|
from utils.db_pool import DatabasePool
|
|
from utils.error_handlers import register_error_handlers
|
|
from middleware.security import set_secure_headers, log_request_info, require_https
|
|
|
|
# Load environment variables (python-dotenv) before any os.getenv() call below
load_dotenv()
|
|
|
|
def get_db_connection_interactive():
    """Prompt on the terminal for MySQL connection settings and connect.

    Every prompt may be skipped by pressing Enter. In that case the matching
    environment variable (DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME) is
    used, and only when that is unset does the built-in default shown in the
    prompt apply. An unparsable port falls back to 3306 with a warning.

    When DB_SSL_CA is set, its value is passed to PyMySQL as the CA file for
    the TLS connection.

    Returns:
        An open ``pymysql`` connection using ``DictCursor`` and ``utf8mb4``.

    Raises:
        pymysql.MySQLError: If connecting fails (the error is logged first).
    """
    print("请依次输入数据库连接信息(直接按回车使用默认值):")

    # Typed answers are stripped so stray whitespace does not end up in the
    # host/user/database name; an answer of only spaces counts as "skipped".
    host = input(" 1. 主机 (默认: localhost): ").strip() or os.getenv('DB_HOST', 'localhost')
    port_str = input(" 2. 端口 (默认: 3306): ").strip() or os.getenv('DB_PORT', '3306')
    try:
        port = int(port_str)
    except ValueError:
        logging.warning("端口号无效,使用默认端口 3306。")
        port = 3306

    user = input(" 3. 用户名 (默认: root): ").strip() or os.getenv('DB_USER', 'root')
    # The password is taken verbatim: leading/trailing spaces may be intentional.
    password = getpass.getpass(" 4. 密码: ") or os.getenv('DB_PASSWORD', '')
    db_name = (
        input(" 5. 数据库名 (默认: Weibo_PublicOpinion_AnalysisSystem): ").strip()
        or os.getenv('DB_NAME', 'Weibo_PublicOpinion_AnalysisSystem')
    )

    logging.info(f"尝试连接到数据库: {user}@{host}:{port}/{db_name}")

    # PyMySQL looks up 'ca' at the top level of the ``ssl`` mapping. Wrapping
    # it in another {'ssl': ...} level hides the CA file, so the certificate
    # would silently not be verified.
    ssl_ca = os.getenv('DB_SSL_CA')
    ssl_options = {'ca': ssl_ca} if ssl_ca else None

    try:
        connection = pymysql.connect(
            host=host,
            port=port,
            user=user,
            password=password,
            database=db_name,
            charset='utf8mb4',
            cursorclass=pymysql.cursors.DictCursor,
            ssl=ssl_options,
        )
    except pymysql.MySQLError as e:
        logging.error(f"数据库连接失败: {e}")
        raise

    logging.info("数据库连接成功。")
    return connection
|
|
|
|
# Create the Flask application and configure its session cookie
app = Flask(__name__)

# Without FLASK_SECRET_KEY a fresh random key is generated on every start.
app.secret_key = os.getenv('FLASK_SECRET_KEY', secrets.token_hex(32))

app.config.update(
    SESSION_COOKIE_SECURE=True,
    SESSION_COOKIE_HTTPONLY=True,
    SESSION_COOKIE_SAMESITE='Lax',
    PERMANENT_SESSION_LIFETIME=timedelta(hours=2),
)
|
|
|
|
# Import the blueprints
# NOTE(review): these imports sit below the creation of `app`, presumably so
# the view modules can import it without a circular import — confirm.
from views.page import page
from views.user import user
from views.spider_control import spider_bp
from views.workflow_api import workflow_bp, workflow_api_bp

# Attach every blueprint to the application
app.register_blueprint(page.pb)
app.register_blueprint(user.ub)
app.register_blueprint(spider_bp)
app.register_blueprint(workflow_bp)
app.register_blueprint(workflow_api_bp)

# Register the error handlers
register_error_handlers(app)
|
|
|
|
# Index route
@app.route('/')
@require_https()
def hello_world():
    """Clear the current session and redirect the visitor to ``/user/login``."""
    session.clear()
    login_redirect = redirect('/user/login')
    return login_redirect
|
|
|
|
# Pre-request middleware
@app.before_request
def before_request():
    """Log the request and require a session bound to the current client.

    Paths starting with ``/static`` and the login/register pages pass through
    unchecked. Every other request needs a session holding ``username`` and a
    ``client_info`` entry whose ``ip`` and ``user_agent`` equal those of the
    current request. Otherwise the caller is redirected to ``/user/login``;
    the session is cleared first when ``client_info`` is missing or does not
    match.

    Returns:
        ``None`` to let the request continue, or a redirect response.
    """
    log_request_info()

    path = request.path
    # NOTE(review): the prefix test also exempts paths such as '/staticfoo'
    # from the session check — confirm this is intended.
    if path.startswith('/static') or path in ('/user/login', '/user/register'):
        return None

    # Not logged in: redirect without touching the session.
    if not session.get('username'):
        return redirect('/user/login')

    if 'client_info' in session:
        expected = session.get('client_info', {})
        observed = (request.remote_addr, str(request.user_agent))
        if observed == (expected.get('ip'), expected.get('user_agent')):
            return None

    # client_info is missing or belongs to a different client: drop the session.
    session.clear()
    return redirect('/user/login')
|
|
|
|
# Post-response middleware
@app.after_request
def after_request(response):
    """Pass each outgoing response through ``set_secure_headers`` and return the result."""
    secured = set_secure_headers(response)
    return secured
|
|
|
|
# Database settings handed to DatabasePool.initialize(); each value comes from
# the environment, with a default when the variable is unset.
DB_CONFIG = dict(
    host=os.getenv('DB_HOST', 'localhost'),
    user=os.getenv('DB_USER', 'root'),
    password=os.getenv('DB_PASSWORD', ''),
    database=os.getenv('DB_NAME', 'Weibo_PublicOpinion_AnalysisSystem'),
    port=int(os.getenv('DB_PORT', '3306')),
    charset='utf8mb4',
    # TLS options are supplied only when a CA file is configured.
    ssl=None if not os.getenv('DB_SSL_CA') else {'ca': os.getenv('DB_SSL_CA')},
)
|
|
|
|
if __name__ == '__main__':
    # NOTE(review): initialize_database, check_database_empty and dynamic_crawl
    # are not defined or imported in the visible part of this file — confirm
    # they exist, otherwise the calls below raise NameError.

    # Optionally create the schema first (INITIALIZE_DB=true).
    try:
        if os.getenv('INITIALIZE_DB', 'false').lower() == 'true':
            connection = get_db_connection_interactive()
            try:
                sql_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'createTables.sql')
                initialize_database(connection, sql_file)
            finally:
                # Close the connection even when running the SQL script fails.
                connection.close()
            logging.info("数据库初始化完成。")
    except Exception as e:
        logging.error(f"数据库初始化失败: {e}")
        raise SystemExit(1)

    # Initialise the database connection pool.
    try:
        DatabasePool.initialize(DB_CONFIG)
    except Exception as e:
        logging.error(f"数据库连接池初始化失败: {e}")
        raise SystemExit(1)

    # Start the scheduler, arrange the first crawl, then serve the application.
    scheduler = None
    try:
        scheduler_tz = ZoneInfo("UTC")
        scheduler = BackgroundScheduler(timezone=scheduler_tz)
        scheduler.start()

        if check_database_empty():
            logging.info("数据库为空。立即开始初始爬取。")
            dynamic_crawl()
        else:
            logging.info("数据库已有数据。安排首次爬取。")
            base_interval = int(os.getenv('CRAWL_INTERVAL', '18000'))  # default: 5 hours
            scheduler.add_job(
                dynamic_crawl,
                'date',
                # Use an aware timestamp in the scheduler's own timezone; a
                # naive local datetime.now() would be read as UTC and shift
                # the first run by the host's UTC offset.
                run_date=datetime.now(scheduler_tz) + timedelta(seconds=base_interval),
                id='dynamic_crawl',
            )

        # Run the web application (blocks until the server stops).
        app.run(
            host=os.getenv('FLASK_HOST', '127.0.0.1'),
            port=int(os.getenv('FLASK_PORT', '5000')),
            ssl_context='adhoc' if os.getenv('ENABLE_HTTPS', 'false').lower() == 'true' else None,
        )
    except Exception as e:
        logging.error(f"应用启动失败: {e}")
        # Cleanup happens once, in ``finally``; doing it here as well would
        # shut the scheduler down twice.
        raise SystemExit(1)
    finally:
        # Only a running scheduler can be shut down.
        if scheduler is not None and scheduler.running:
            scheduler.shutdown()
        DatabasePool.close()
|