diff --git a/InsightEngine/utils/db.py b/InsightEngine/utils/db.py index 78410ef..405c4f1 100644 --- a/InsightEngine/utils/db.py +++ b/InsightEngine/utils/db.py @@ -7,7 +7,7 @@ """ from __future__ import annotations - +from urllib.parse import quote_plus import asyncio import os from typing import Any, Dict, Iterable, List, Optional, Union @@ -36,6 +36,8 @@ def _build_database_url() -> str: if os.getenv("DATABASE_URL"): return os.getenv("DATABASE_URL") # 直接使用外部提供的完整URL + password = quote_plus(password) + if dialect in ("postgresql", "postgres"): # PostgreSQL 使用 asyncpg 驱动 return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{db_name}" diff --git a/MindSpider/main.py b/MindSpider/main.py index 60d5898..bdee1d8 100644 --- a/MindSpider/main.py +++ b/MindSpider/main.py @@ -18,6 +18,7 @@ from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine from sqlalchemy import inspect, text from config import settings from loguru import logger +from urllib.parse import quote_plus # 添加项目根目录到路径 project_root = Path(__file__).parent @@ -73,10 +74,10 @@ class MindSpider: def build_async_url() -> str: dialect = (settings.DB_DIALECT or "mysql").lower() if dialect == "postgresql": - return f"postgresql+asyncpg://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}" + return f"postgresql+asyncpg://{settings.DB_USER}:{quote_plus(settings.DB_PASSWORD)}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}" # 默认使用 mysql 异步驱动 asyncmy return ( - f"mysql+asyncmy://{settings.DB_USER}:{settings.DB_PASSWORD}" + f"mysql+asyncmy://{settings.DB_USER}:{quote_plus(settings.DB_PASSWORD)}" f"@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}?charset={settings.DB_CHARSET}" ) @@ -104,9 +105,9 @@ class MindSpider: def build_async_url() -> str: dialect = (settings.DB_DIALECT or "mysql").lower() if dialect == "postgresql": - return f"postgresql+asyncpg://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}" + return f"postgresql+asyncpg://{settings.DB_USER}:{quote_plus(settings.DB_PASSWORD)}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}" return ( - f"mysql+asyncmy://{settings.DB_USER}:{settings.DB_PASSWORD}" + f"mysql+asyncmy://{settings.DB_USER}:{quote_plus(settings.DB_PASSWORD)}" f"@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}?charset={settings.DB_CHARSET}" ) diff --git a/MindSpider/schema/db_manager.py b/MindSpider/schema/db_manager.py index 568df6e..26f423f 100644 --- a/MindSpider/schema/db_manager.py +++ b/MindSpider/schema/db_manager.py @@ -13,6 +13,7 @@ import argparse from pathlib import Path from datetime import datetime, timedelta from loguru import logger +from urllib.parse import quote_plus # 添加项目根目录到路径 project_root = Path(__file__).parent.parent @@ -36,9 +37,9 @@ class DatabaseManager: try: dialect = (settings.DB_DIALECT or "mysql").lower() if dialect in ("postgresql", "postgres"): - url = f"postgresql+psycopg://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}" + url = f"postgresql+psycopg://{settings.DB_USER}:{quote_plus(settings.DB_PASSWORD)}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}" else: - url = f"mysql+pymysql://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}?charset={settings.DB_CHARSET}" + url = f"mysql+pymysql://{settings.DB_USER}:{quote_plus(settings.DB_PASSWORD)}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}?charset={settings.DB_CHARSET}" self.engine = create_engine(url, future=True) logger.info(f"成功连接到数据库: {settings.DB_NAME}") except Exception as e: diff --git a/MindSpider/schema/init_database.py b/MindSpider/schema/init_database.py index d561625..e227525 100644 --- a/MindSpider/schema/init_database.py +++ b/MindSpider/schema/init_database.py @@ -13,7 +13,7 @@ from __future__ import annotations import asyncio import os from typing import Optional - +from urllib.parse import quote_plus from loguru import logger from sqlalchemy.ext.asyncio import create_async_engine @@ -49,6 +49,7 @@ def _build_database_url() -> str: port = str(settings.DB_PORT or ("3306" if dialect == "mysql" else "5432")) user = settings.DB_USER or "root" password = settings.DB_PASSWORD or "" + password = quote_plus(password) db_name = settings.DB_NAME or "mindspider" if dialect in ("postgresql", "postgres"):