Improve fault tolerance of the sentiment analysis module and refine the keyword optimizer prompts.
This commit is contained in:
@@ -16,7 +16,6 @@ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(_
|
||||
weibo_sentiment_path = os.path.join(project_root, "SentimentAnalysisModel", "WeiboMultilingualSentiment")
|
||||
sys.path.append(weibo_sentiment_path)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SentimentResult:
|
||||
"""情感分析结果数据类"""
|
||||
@@ -26,6 +25,7 @@ class SentimentResult:
|
||||
probability_distribution: Dict[str, float]
|
||||
success: bool = True
|
||||
error_message: Optional[str] = None
|
||||
analysis_performed: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -36,6 +36,7 @@ class BatchSentimentResult:
|
||||
success_count: int
|
||||
failed_count: int
|
||||
average_confidence: float
|
||||
analysis_performed: bool = True
|
||||
|
||||
|
||||
class WeiboMultilingualSentimentAnalyzer:
|
||||
@@ -50,6 +51,7 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
self.tokenizer = None
|
||||
self.device = None
|
||||
self.is_initialized = False
|
||||
self.is_disabled = False
|
||||
|
||||
# 情感标签映射(5级分类)
|
||||
self.sentiment_map = {
|
||||
@@ -69,6 +71,10 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
Returns:
|
||||
是否初始化成功
|
||||
"""
|
||||
if self.is_disabled:
|
||||
print("情感分析功能已禁用,跳过模型加载")
|
||||
return False
|
||||
|
||||
if self.is_initialized:
|
||||
print("模型已经初始化,无需重复加载")
|
||||
return True
|
||||
@@ -102,6 +108,7 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
self.is_initialized = True
|
||||
self.is_disabled = False
|
||||
|
||||
print(f"模型加载成功! 使用设备: {self.device}")
|
||||
print("支持语言: 中文、英文、西班牙文、阿拉伯文、日文、韩文等22种语言")
|
||||
@@ -113,6 +120,11 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
print(f"模型加载失败: {e}")
|
||||
print("请检查网络连接或模型文件")
|
||||
self.is_initialized = False
|
||||
self.is_disabled = True
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.device = None
|
||||
print("情感分析功能已禁用,将直接返回原始文本内容")
|
||||
return False
|
||||
|
||||
def _preprocess_text(self, text: str) -> str:
|
||||
@@ -144,6 +156,17 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
Returns:
|
||||
SentimentResult对象
|
||||
"""
|
||||
if self.is_disabled:
|
||||
return SentimentResult(
|
||||
text=text,
|
||||
sentiment_label="情感分析未执行",
|
||||
confidence=0.0,
|
||||
probability_distribution={},
|
||||
success=False,
|
||||
error_message="情感分析功能已禁用",
|
||||
analysis_performed=False
|
||||
)
|
||||
|
||||
if not self.is_initialized:
|
||||
return SentimentResult(
|
||||
text=text,
|
||||
@@ -151,13 +174,14 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
confidence=0.0,
|
||||
probability_distribution={},
|
||||
success=False,
|
||||
error_message="模型未初始化,请先调用 initialize() 方法"
|
||||
error_message="模型未初始化,请先调用initialize() 方法",
|
||||
analysis_performed=False
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# 预处理文本
|
||||
processed_text = self._preprocess_text(text)
|
||||
|
||||
|
||||
if not processed_text:
|
||||
return SentimentResult(
|
||||
text=text,
|
||||
@@ -165,9 +189,10 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
confidence=0.0,
|
||||
probability_distribution={},
|
||||
success=False,
|
||||
error_message="输入文本为空或无效"
|
||||
error_message="输入文本为空或无效内容",
|
||||
analysis_performed=False
|
||||
)
|
||||
|
||||
|
||||
# 分词编码
|
||||
inputs = self.tokenizer(
|
||||
processed_text,
|
||||
@@ -176,26 +201,26 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
truncation=True,
|
||||
return_tensors='pt'
|
||||
)
|
||||
|
||||
|
||||
# 转移到设备
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
|
||||
# 预测
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
logits = outputs.logits
|
||||
probabilities = torch.softmax(logits, dim=1)
|
||||
prediction = torch.argmax(probabilities, dim=1).item()
|
||||
|
||||
|
||||
# 构建结果
|
||||
confidence = probabilities[0][prediction].item()
|
||||
label = self.sentiment_map[prediction]
|
||||
|
||||
|
||||
# 构建概率分布字典
|
||||
prob_dist = {}
|
||||
for i, (label_name, prob) in enumerate(zip(self.sentiment_map.values(), probabilities[0])):
|
||||
for label_name, prob in zip(self.sentiment_map.values(), probabilities[0]):
|
||||
prob_dist[label_name] = prob.item()
|
||||
|
||||
|
||||
return SentimentResult(
|
||||
text=text,
|
||||
sentiment_label=label,
|
||||
@@ -203,7 +228,7 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
probability_distribution=prob_dist,
|
||||
success=True
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return SentimentResult(
|
||||
text=text,
|
||||
@@ -211,9 +236,10 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
confidence=0.0,
|
||||
probability_distribution={},
|
||||
success=False,
|
||||
error_message=f"预测时发生错误: {str(e)}"
|
||||
error_message=f"预测时发生错误: {str(e)}",
|
||||
analysis_performed=False
|
||||
)
|
||||
|
||||
|
||||
def analyze_batch(self, texts: List[str], show_progress: bool = True) -> BatchSentimentResult:
|
||||
"""
|
||||
批量情感分析
|
||||
@@ -231,7 +257,30 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
total_processed=0,
|
||||
success_count=0,
|
||||
failed_count=0,
|
||||
average_confidence=0.0
|
||||
average_confidence=0.0,
|
||||
analysis_performed=not self.is_disabled and self.is_initialized
|
||||
)
|
||||
|
||||
if self.is_disabled or not self.is_initialized:
|
||||
passthrough_results = [
|
||||
SentimentResult(
|
||||
text=text,
|
||||
sentiment_label="情感分析未执行",
|
||||
confidence=0.0,
|
||||
probability_distribution={},
|
||||
success=False,
|
||||
error_message="情感分析功能不可用",
|
||||
analysis_performed=False
|
||||
)
|
||||
for text in texts
|
||||
]
|
||||
return BatchSentimentResult(
|
||||
results=passthrough_results,
|
||||
total_processed=len(texts),
|
||||
success_count=0,
|
||||
failed_count=len(texts),
|
||||
average_confidence=0.0,
|
||||
analysis_performed=False
|
||||
)
|
||||
|
||||
results = []
|
||||
@@ -257,9 +306,46 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
total_processed=len(texts),
|
||||
success_count=success_count,
|
||||
failed_count=failed_count,
|
||||
average_confidence=average_confidence
|
||||
average_confidence=average_confidence,
|
||||
analysis_performed=True
|
||||
)
|
||||
|
||||
def _build_passthrough_analysis(
    self,
    original_data: List[Dict[str, Any]],
    reason: str,
    texts: Optional[List[str]] = None,
    results: Optional[List[SentimentResult]] = None
) -> Dict[str, Any]:
    """Build the pass-through response used when sentiment analysis is unavailable.

    The response has the same top-level ``"sentiment_analysis"`` key as a
    normal analysis, but is flagged ``available=False``, reports zero
    analysed items, and carries the untouched input so callers can still
    work with the raw content.

    Args:
        original_data: The records that were supposed to be analysed;
            returned unchanged under ``"original_texts"``.
        reason: Explanation of why the analysis was skipped; stored under
            ``"reason"`` and embedded in ``"summary"``.
        texts: Optional list of extracted texts. When given, it determines
            the denominator of ``"success_rate"`` and is returned under
            ``"passthrough_texts"``.
        results: Optional per-text results. When given, they are returned
            under ``"results"`` with every ``SentimentResult`` converted to
            a plain dict; any other item is passed through as-is.

    Returns:
        A dict of the form ``{"sentiment_analysis": {...}}``.
    """
    # Local import keeps this method self-contained regardless of which
    # names the module imports from ``dataclasses``.
    from dataclasses import asdict

    # Prefer the number of extracted texts; fall back to the raw records.
    total_items = len(texts) if texts is not None else len(original_data)

    analysis: Dict[str, Any] = {
        "available": False,
        "reason": reason,
        "total_analyzed": 0,
        "success_rate": f"0/{total_items}",
        "average_confidence": 0.0,
        "sentiment_distribution": {},
        "high_confidence_results": [],
        "summary": f"情感分析未执行:{reason}",
        "original_texts": original_data,
    }

    if texts is not None:
        analysis["passthrough_texts"] = texts

    if results is not None:
        # asdict() returns an independent copy; exposing ``__dict__`` would
        # hand out the live attribute dict, so editing the response would
        # silently mutate the SentimentResult objects.
        analysis["results"] = [
            asdict(result) if isinstance(result, SentimentResult) else result
            for result in results
        ]

    return {"sentiment_analysis": analysis}
|
||||
|
||||
def analyze_query_results(self, query_results: List[Dict[str, Any]],
|
||||
text_field: str = "content",
|
||||
min_confidence: float = 0.5) -> Dict[str, Any]:
|
||||
@@ -311,10 +397,30 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
}
|
||||
}
|
||||
|
||||
if self.is_disabled:
|
||||
return self._build_passthrough_analysis(
|
||||
original_data=original_data,
|
||||
reason="情感分析模型不可用",
|
||||
texts=texts_to_analyze
|
||||
)
|
||||
|
||||
# 执行批量情感分析
|
||||
print(f"正在对{len(texts_to_analyze)}条内容进行情感分析...")
|
||||
batch_result = self.analyze_batch(texts_to_analyze, show_progress=True)
|
||||
|
||||
if not batch_result.analysis_performed:
|
||||
reason = "情感分析功能不可用"
|
||||
if batch_result.results:
|
||||
candidate_error = next((r.error_message for r in batch_result.results if r.error_message), None)
|
||||
if candidate_error:
|
||||
reason = candidate_error
|
||||
return self._build_passthrough_analysis(
|
||||
original_data=original_data,
|
||||
reason=reason,
|
||||
texts=texts_to_analyze,
|
||||
results=batch_result.results
|
||||
)
|
||||
|
||||
# 统计情感分布
|
||||
sentiment_distribution = {}
|
||||
high_confidence_results = []
|
||||
@@ -392,31 +498,18 @@ def analyze_sentiment(text_or_texts: Union[str, List[str]],
|
||||
Returns:
|
||||
SentimentResult或BatchSentimentResult
|
||||
"""
|
||||
if initialize_if_needed and not multilingual_sentiment_analyzer.is_initialized:
|
||||
if not multilingual_sentiment_analyzer.initialize():
|
||||
# 如果初始化失败,返回失败结果
|
||||
if isinstance(text_or_texts, str):
|
||||
return SentimentResult(
|
||||
text=text_or_texts,
|
||||
sentiment_label="初始化失败",
|
||||
confidence=0.0,
|
||||
probability_distribution={},
|
||||
success=False,
|
||||
error_message="模型初始化失败"
|
||||
)
|
||||
else:
|
||||
return BatchSentimentResult(
|
||||
results=[],
|
||||
total_processed=0,
|
||||
success_count=0,
|
||||
failed_count=len(text_or_texts),
|
||||
average_confidence=0.0
|
||||
)
|
||||
if (
|
||||
initialize_if_needed
|
||||
and not multilingual_sentiment_analyzer.is_initialized
|
||||
and not multilingual_sentiment_analyzer.is_disabled
|
||||
):
|
||||
multilingual_sentiment_analyzer.initialize()
|
||||
|
||||
if isinstance(text_or_texts, str):
|
||||
return multilingual_sentiment_analyzer.analyze_single_text(text_or_texts)
|
||||
else:
|
||||
return multilingual_sentiment_analyzer.analyze_batch(text_or_texts)
|
||||
texts_list = list(text_or_texts)
|
||||
return multilingual_sentiment_analyzer.analyze_batch(texts_list)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user