Completed requirements.txt, fixed the Dockerfile, and updated the README. Significantly refactored the sentiment analyzer to be more robust against missing machine learning dependencies and controllable via a toggle.
This commit is contained in:
@@ -3,14 +3,39 @@
|
||||
基于WeiboMultilingualSentiment模型为InsightEngine提供情感分析功能
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
import re
|
||||
|
||||
try:
|
||||
import torch
|
||||
TORCH_AVAILABLE = True
|
||||
except ImportError:
|
||||
torch = None # type: ignore
|
||||
TORCH_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
TRANSFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
AutoTokenizer = None # type: ignore
|
||||
AutoModelForSequenceClassification = None # type: ignore
|
||||
TRANSFORMERS_AVAILABLE = False
|
||||
|
||||
|
||||
# INFO:若想跳过情感分析,可手动切换此开关为False
|
||||
SENTIMENT_ANALYSIS_ENABLED = True
|
||||
|
||||
def _describe_missing_dependencies() -> str:
    """Return a human-readable list of the ML dependencies that failed to import.

    Returns:
        The names of the missing packages joined by " / " (for example
        "PyTorch / Transformers"), or an empty string when both optional
        imports succeeded.
    """
    # Pair each display name with the availability flag set by the guarded
    # imports at module load time.
    availability = (
        ("PyTorch", TORCH_AVAILABLE),
        ("Transformers", TRANSFORMERS_AVAILABLE),
    )
    return " / ".join(label for label, present in availability if not present)
|
||||
|
||||
# 添加项目根目录到路径,以便导入WeiboMultilingualSentiment
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
weibo_sentiment_path = os.path.join(project_root, "SentimentAnalysisModel", "WeiboMultilingualSentiment")
|
||||
@@ -52,6 +77,7 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
self.device = None
|
||||
self.is_initialized = False
|
||||
self.is_disabled = False
|
||||
self.disable_reason: Optional[str] = None
|
||||
|
||||
# 情感标签映射(5级分类)
|
||||
self.sentiment_map = {
|
||||
@@ -61,8 +87,52 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
3: "正面",
|
||||
4: "非常正面"
|
||||
}
|
||||
|
||||
print("WeiboMultilingualSentimentAnalyzer 已创建,调用 initialize() 来加载模型")
|
||||
|
||||
if not SENTIMENT_ANALYSIS_ENABLED:
|
||||
self.disable("情感分析功能已在配置中关闭。")
|
||||
elif not (TORCH_AVAILABLE and TRANSFORMERS_AVAILABLE):
|
||||
missing = _describe_missing_dependencies() or "未知依赖"
|
||||
self.disable(f"缺少依赖: {missing},情感分析已禁用。")
|
||||
|
||||
if self.is_disabled:
|
||||
reason = self.disable_reason or "Sentiment analysis disabled."
|
||||
print(f"WeiboMultilingualSentimentAnalyzer initialized but disabled: {reason}")
|
||||
else:
|
||||
print("WeiboMultilingualSentimentAnalyzer 已创建,调用 initialize() 来加载模型")
|
||||
|
||||
def disable(self, reason: Optional[str] = None, drop_state: bool = False) -> None:
    """Mark sentiment analysis as disabled and record why.

    Args:
        reason: Explanation stored in ``disable_reason``. A falsy value
            (``None`` or an empty string) falls back to a generic message.
        drop_state: When true, also forget the loaded model, tokenizer and
            device, and reset ``is_initialized`` to ``False``.
    """
    self.is_disabled = True
    self.disable_reason = reason if reason else "Sentiment analysis disabled."

    # Guard clause: keep any loaded resources unless explicitly told to drop them.
    if not drop_state:
        return

    for resource in ("model", "tokenizer", "device"):
        setattr(self, resource, None)
    self.is_initialized = False
|
||||
|
||||
def enable(self) -> bool:
    """Try to switch sentiment analysis back on.

    Enabling is refused when the module-level toggle is off or when either
    optional ML dependency failed to import; in both cases the analyzer is
    (re-)disabled with the corresponding reason.

    Returns:
        True if the analyzer is now enabled, False if it stayed disabled.
    """
    blocker: Optional[str] = None
    if not SENTIMENT_ANALYSIS_ENABLED:
        # The configuration switch takes precedence over dependency checks.
        blocker = "情感分析功能已在配置中关闭。"
    elif not TORCH_AVAILABLE or not TRANSFORMERS_AVAILABLE:
        missing = _describe_missing_dependencies() or "未知依赖"
        blocker = f"缺少依赖: {missing},情感分析已禁用。"

    if blocker is not None:
        self.disable(blocker)
        return False

    # Nothing blocks us: clear the disabled flag and its stored reason.
    self.is_disabled = False
    self.disable_reason = None
    return True
|
||||
|
||||
def _select_device(self):
    """Pick the torch device to run inference on.

    Preference order is CUDA, then Apple MPS (only when the backend reports
    itself as both available and built), then CPU.

    Returns:
        A ``torch.device``, or ``None`` when PyTorch could not be imported.
    """
    if not TORCH_AVAILABLE:
        return None

    device_name = "cpu"
    if torch.cuda.is_available():
        device_name = "cuda"
    else:
        # Older torch builds have no ``torch.backends.mps``; probe defensively
        # and treat any missing attribute as "not usable".
        mps = getattr(torch.backends, "mps", None)
        if mps:
            if getattr(mps, "is_available", lambda: False)():
                if getattr(mps, "is_built", lambda: False)():
                    device_name = "mps"

    return torch.device(device_name)
|
||||
|
||||
def initialize(self) -> bool:
|
||||
"""
|
||||
@@ -72,7 +142,14 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
是否初始化成功
|
||||
"""
|
||||
if self.is_disabled:
|
||||
print("情感分析功能已禁用,跳过模型加载")
|
||||
reason = self.disable_reason or "情感分析功能已禁用"
|
||||
print(f"情感分析功能已禁用,跳过模型加载:{reason}")
|
||||
return False
|
||||
|
||||
if not (TORCH_AVAILABLE and TRANSFORMERS_AVAILABLE):
|
||||
missing = _describe_missing_dependencies() or "未知依赖"
|
||||
self.disable(f"缺少依赖: {missing},情感分析已禁用。", drop_state=True)
|
||||
print(f"缺少依赖: {missing},无法加载情感分析模型。")
|
||||
return False
|
||||
|
||||
if self.is_initialized:
|
||||
@@ -104,11 +181,23 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
print(f"模型已保存到: {local_model_path}")
|
||||
|
||||
# 设置设备
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
device = self._select_device()
|
||||
if device is None:
|
||||
raise RuntimeError("未检测到可用的计算设备")
|
||||
|
||||
self.device = device
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
self.is_initialized = True
|
||||
self.is_disabled = False
|
||||
self.enable()
|
||||
|
||||
device_type = getattr(self.device, "type", str(self.device))
|
||||
if device_type == "cuda":
|
||||
print("检测到可用 GPU,已优先使用 CUDA 进行推理。")
|
||||
elif device_type == "mps":
|
||||
print("检测到 Apple MPS 设备,已使用 MPS 进行推理。")
|
||||
else:
|
||||
print("未检测到 GPU,自动使用 CPU 进行推理。")
|
||||
|
||||
print(f"模型加载成功! 使用设备: {self.device}")
|
||||
print("支持语言: 中文、英文、西班牙文、阿拉伯文、日文、韩文等22种语言")
|
||||
@@ -117,14 +206,10 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"模型加载失败: {e}")
|
||||
error_message = f"模型加载失败: {e}"
|
||||
print(error_message)
|
||||
print("请检查网络连接或模型文件")
|
||||
self.is_initialized = False
|
||||
self.is_disabled = True
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.device = None
|
||||
print("情感分析功能已禁用,将直接返回原始文本内容")
|
||||
self.disable(error_message, drop_state=True)
|
||||
return False
|
||||
|
||||
def _preprocess_text(self, text: str) -> str:
|
||||
@@ -163,7 +248,7 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
confidence=0.0,
|
||||
probability_distribution={},
|
||||
success=False,
|
||||
error_message="情感分析功能已禁用",
|
||||
error_message=self.disable_reason or "情感分析功能已禁用",
|
||||
analysis_performed=False
|
||||
)
|
||||
|
||||
@@ -269,7 +354,7 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
confidence=0.0,
|
||||
probability_distribution={},
|
||||
success=False,
|
||||
error_message="情感分析功能不可用",
|
||||
error_message=self.disable_reason or "情感分析功能不可用",
|
||||
analysis_performed=False
|
||||
)
|
||||
for text in texts
|
||||
@@ -318,7 +403,7 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
results: Optional[List[SentimentResult]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
构建在情感分析不可用时的透传结�?
|
||||
构建在情感分析不可用时的透传结果
|
||||
"""
|
||||
total_items = len(texts) if texts is not None else len(original_data)
|
||||
response: Dict[str, Any] = {
|
||||
@@ -400,7 +485,7 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
if self.is_disabled:
|
||||
return self._build_passthrough_analysis(
|
||||
original_data=original_data,
|
||||
reason="情感分析模型不可用",
|
||||
reason=self.disable_reason or "情感分析模型不可用",
|
||||
texts=texts_to_analyze
|
||||
)
|
||||
|
||||
@@ -409,7 +494,7 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
batch_result = self.analyze_batch(texts_to_analyze, show_progress=True)
|
||||
|
||||
if not batch_result.analysis_performed:
|
||||
reason = "情感分析功能不可用"
|
||||
reason = self.disable_reason or "情感分析功能不可用"
|
||||
if batch_result.results:
|
||||
candidate_error = next((r.error_message for r in batch_result.results if r.error_message), None)
|
||||
if candidate_error:
|
||||
@@ -486,6 +571,16 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
multilingual_sentiment_analyzer = WeiboMultilingualSentimentAnalyzer()
|
||||
|
||||
|
||||
def enable_sentiment_analysis() -> bool:
    """Enable sentiment analysis on the shared module-level analyzer.

    Returns:
        Whatever the analyzer's ``enable()`` call returns.
    """
    analyzer = multilingual_sentiment_analyzer
    return analyzer.enable()
|
||||
|
||||
|
||||
def disable_sentiment_analysis(reason: Optional[str] = None, drop_state: bool = False) -> None:
    """Disable sentiment analysis on the shared module-level analyzer.

    Args:
        reason: Optional explanation forwarded to the analyzer's ``disable()``.
        drop_state: Forwarded to ``disable()`` as ``drop_state``.
    """
    analyzer = multilingual_sentiment_analyzer
    analyzer.disable(reason=reason, drop_state=drop_state)
|
||||
|
||||
|
||||
def analyze_sentiment(text_or_texts: Union[str, List[str]],
|
||||
initialize_if_needed: bool = True) -> Union[SentimentResult, BatchSentimentResult]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user