286 lines
10 KiB
Python
286 lines
10 KiB
Python
from flask import Flask, session, render_template, redirect, Blueprint, request
|
|
from utils.mynlp import SnowNLP
|
|
from utils.getHomePageData import *
|
|
from utils.getHotWordPageData import *
|
|
from utils.getTableData import *
|
|
from utils.getPublicData import getAllHotWords, getAllTopics, getArticleByType, getArticleById
|
|
from utils.getEchartsData import *
|
|
from utils.getTopicPageData import *
|
|
from utils.yuqingpredict import *
|
|
from utils.logger import app_logger as logging
|
|
import torch
|
|
from model_pro.MHA import MultiHeadAttentionLayer
|
|
from model_pro.classifier import FinalClassifier
|
|
from model_pro.BERT_CTM import BERT_CTM_Model
|
|
|
|
# Blueprint named 'page': every route below is mounted under the /page URL prefix
# and resolves templates from the blueprint's "templates" folder.
pb = Blueprint(
    'page',
    __name__,
    template_folder='templates',
    url_prefix='/page',
)
|
|
|
|
# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Models are loaded once at import time into module globals so that requests
# do not reload them.
model_save_path = 'model_pro/final_model.pt'
bert_model_path = 'model_pro/bert_model'
ctm_tokenizer_path = 'model_pro/sentence_bert_model'

# Bind every model name up front. If loading fails part-way, the names still exist
# (as None), and predict_sentiment can detect the missing model instead of hitting
# a NameError.
classifier_model = None
attention_model = None
bert_ctm_model = None

try:
    # NOTE(review): torch.load unpickles a whole model object, which can execute
    # arbitrary code, so only load trusted files. Newer PyTorch releases default to
    # weights_only=True, which rejects full-module pickles; confirm against the
    # installed version if loading starts failing.
    classifier_model = torch.load(model_save_path, map_location=device)
    classifier_model.eval()

    attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
    attention_model.to(device)
    attention_model.eval()

    bert_ctm_model = BERT_CTM_Model(
        bert_model_path=bert_model_path,
        ctm_tokenizer_path=ctm_tokenizer_path,
    )
except Exception as e:
    # Best-effort: the web app keeps running without the models, and
    # predict_sentiment reports a failed prediction instead.
    logging.error(f"模型加载失败: {e}")
|
|
|
|
def predict_sentiment(text):
    """Predict the sentiment class of a single text using the loaded model pipeline.

    The text is embedded with ``bert_ctm_model``, averaged over dim 1, passed
    through ``attention_model`` (as query, key and value) and then
    ``classifier_model``. The classifier output is averaged over dim 1, and the
    arg-max class is chosen.

    Args:
        text: the string to classify.

    Returns:
        ``(label, confidence)``: ``label`` is the predicted class index (int) and
        ``confidence`` is that class's softmax probability (float). Returns
        ``(None, None)`` when the models are not loaded or any step raises.
    """
    try:
        # The model globals are None when module-level loading failed.
        if bert_ctm_model is None or attention_model is None or classifier_model is None:
            logging.error("模型未加载,无法进行情感预测")
            return None, None

        with torch.no_grad():
            # Presumably (batch, seq_len, hidden) embeddings; TODO confirm with BERT_CTM_Model.
            embeddings = bert_ctm_model.get_bert_embeddings([text])
            batch_x = torch.as_tensor(embeddings, dtype=torch.float32, device=device)
            # Average over dim 1 to obtain one vector per input text.
            batch_x = torch.mean(batch_x, dim=1)

            # Self-attention: the same tensor serves as query, key and value.
            attention_output = attention_model(batch_x, batch_x, batch_x)

            outputs = classifier_model(attention_output)
            outputs = torch.mean(outputs, dim=1)

            probabilities = torch.softmax(outputs, dim=1)
            # Arg-max over the raw logits, exactly as before; only one text is in the batch.
            _, predicted = torch.max(outputs, 1)
            label = predicted[0].item()
            confidence = probabilities[0, label].item()

        return label, confidence
    except Exception as e:
        logging.error(f"预测过程中出现错误: {e}")
        return None, None
|
|
|
|
@pb.route('/home')
def home():
    """Render the dashboard home page (index.html) with summary tags and chart data."""
    context = {'username': session.get('username')}

    # Summary tags returned together by getHomeTagsData().
    (context['articleLenMax'],
     context['likeCountMaxAuthorName'],
     context['cityMax']) = getHomeTagsData()

    context['commentsLikeCountTopFore'] = getHomeCommentsLikeCountTopFore()
    context['xData'], context['yData'] = getHomeArticleCreatedAtChart()
    context['typeChart'] = getHomeTypeChart()
    context['createAtChart'] = getHomeCommentCreatedChart()
    # getUserNameWordCloud() is currently disabled and not called here.

    return render_template('index.html', **context)
|
|
|
|
|
|
@pb.route('/hotWord')
def hotWord():
    """Render the hot-word analysis page.

    The hot word comes from the ``hotWord`` query parameter and falls back to the
    first entry of getAllHotWords(). The hot word's own SnowNLP sentiment score is
    turned into a label: > 0.5 '正面' (positive), < 0.5 '负面' (negative),
    == 0.5 '中性' (neutral), otherwise empty.
    """
    username = session.get('username')
    hotWordList = getAllHotWords()
    # Index the list only when no hot word was requested, so an explicit
    # ?hotWord=... still works when the list is empty.
    defaultHotWord = request.args.get('hotWord') or hotWordList[0][0]

    hotWordLen = getHotWordLen(defaultHotWord)
    X, Y = getHotWordPageCreatedAtCharData(defaultHotWord)

    sentences = ''
    value = SnowNLP(defaultHotWord).sentiments
    if value == 0.5:
        sentences = '中性'
    elif value > 0.5:
        sentences = '正面'
    elif value < 0.5:
        sentences = '负面'

    comments = getCommentFilterData(defaultHotWord)
    return render_template('hotWord.html',
                           username=username,
                           hotWordList=hotWordList,
                           defaultHotWord=defaultHotWord,
                           hotWordLen=hotWordLen,
                           sentences=sentences,
                           xData=X,
                           yData=Y,
                           comments=comments)
|
|
|
|
|
|
@pb.route('/hotTopic')
def hotTopic():
    """Render the hot-topic page using the hotWord.html template.

    The topic comes from the ``topic`` query parameter and falls back to the first
    entry of getAllTopics(). ``sentences`` is currently always empty.
    """
    username = session.get('username')
    topicList = getAllTopics()
    # Index the list only when no topic was requested, so an explicit ?topic=...
    # still works when the list is empty.
    defaultTopic = request.args.get('topic') or topicList[0][0]

    topicLen = getTopicLen(defaultTopic)
    # NOTE(review): unlike hotWord(), the time-series helper receives no topic
    # argument, so the chart is presumably not specific to the selected topic.
    # Confirm intent.
    X, Y = getTopicPageCreatedAtCharData()

    # TODO: fill ``sentences`` with topic-related content (e.g. popularity).
    sentences = ''

    comments = getCommentFilterDataTopic(defaultTopic)
    # NOTE(review): hotWord() renders this same template with hotWordList /
    # defaultHotWord / hotWordLen. Confirm the template also reads the
    # topic-named variables passed here.
    return render_template('hotWord.html',
                           username=username,
                           topicList=topicList,
                           defaultTopic=defaultTopic,
                           topicLen=topicLen,
                           sentences=sentences,
                           xData=X,
                           yData=Y,
                           comments=comments)
|
|
|
|
|
|
@pb.route('/tableData')
def tableData():
    """Render the data-table page.

    A non-empty ``flag`` query parameter sets ``defaultFlag`` to True. The flag is
    passed both to getTableDataList() and to the template.
    """
    current_user = session.get('username')
    show_flag = bool(request.args.get('flag'))
    rows = getTableDataList(show_flag)
    return render_template(
        'tableData.html',
        username=current_user,
        tableData=rows,
        defaultFlag=show_flag,
    )
|
|
|
|
|
|
@pb.route('/articleChar')
def articleChar():
    """Render the article chart page for one article type.

    The type comes from the ``type`` query parameter and falls back to the first
    entry of getTypeList(). Three chart series are built for it: like counts,
    comment lengths and repost lengths.
    """
    username = session.get('username')
    typeList = getTypeList()
    # Index the list only when no type was requested, so an explicit ?type=...
    # still works when the list is empty.
    defaultType = request.args.get('type') or typeList[0]

    X, Y = getArticleLikeCount(defaultType)
    x1Data, y1Data = getArticleCommentsLen(defaultType)
    x2Data, y2Data = getArticleRepotsLen(defaultType)
    return render_template('articleChar.html',
                           username=username,
                           typeList=typeList,
                           defaultType=defaultType,
                           xData=X,
                           yData=Y,
                           x1Data=x1Data,
                           y1Data=y1Data,
                           x2Data=x2Data,
                           y2Data=y2Data)
|
|
|
|
|
|
@pb.route('/ipChar')
def ipChar():
    """Render the IP-region page with article and comment region data."""
    current_user = session.get('username')
    article_regions = getIPByArticleRegion()
    comment_regions = getIPByCommentsRegion()
    return render_template(
        'ipChar.html',
        username=current_user,
        articleRegionData=article_regions,
        commentRegionData=comment_regions,
    )
|
|
|
|
|
|
@pb.route('/commentChar')
def commentChar():
    """Render the comment chart page (an x/y series plus gender pie data)."""
    context = {'username': session.get('username')}
    context['xData'], context['yData'] = getCommentDataOne()
    context['genderPieData'] = getCommentDataTwo()
    return render_template('commentChar.html', **context)
|
|
|
|
|
|
@pb.route('/yuqingChar')
def yuqingChar():
    """Render the public-opinion (yuqing) chart page from three data helpers."""
    context = {'username': session.get('username')}
    context['xData'], context['yData'], context['biedata'] = getYuQingCharDataOne()
    context['biedata1'], context['biedata2'] = getYuQingCharDataTwo()
    context['x1Data'], context['y1Data'] = getYuQingCharDataThree()
    return render_template('yuqingChar.html', **context)
|
|
|
|
@pb.route('/yuqingpredict')
def yuqingpredict():
    """Render the public-opinion prediction page for one topic.

    The topic comes from the ``Topic`` query parameter and falls back to the first
    entry of getAllTopicData(). The topic text is classified with
    predict_sentiment(): label 0 is shown as '良好', any other label as '不良',
    followed by the confidence. '预测失败' is shown when prediction fails.
    """
    username = session.get('username')
    TopicList = getAllTopicData()
    # Index the list only when no topic was requested, so an explicit ?Topic=...
    # still works when the list is empty.
    defaultTopic = request.args.get('Topic') or TopicList[0][0]

    TopicLen = getTopicLen(defaultTopic)
    X, Y = getTopicCreatedAtandpredictData(defaultTopic)

    # Sentiment prediction with the improved model.
    predicted_label, confidence = predict_sentiment(defaultTopic)
    if predicted_label is not None:
        sentences = '良好' if predicted_label == 0 else '不良'
        sentences = f"{sentences} (置信度: {confidence:.2f})"
    else:
        sentences = '预测失败'

    comments = getCommentFilterDataTopic(defaultTopic)
    # The template uses the hot-word variable names, so the topic values are
    # passed under those names.
    return render_template('yuqingpredict.html',
                           username=username,
                           hotWordList=TopicList,
                           defaultHotWord=defaultTopic,
                           hotWordLen=TopicLen,
                           sentences=sentences,
                           xData=X,
                           yData=Y,
                           comments=comments)
|
|
|
|
|
|
@pb.route('/articleCloud')
def articleCloud():
    """Render the article-content word-cloud page for the current session user."""
    current_user = session.get('username')
    template_name = 'articleContentCloud.html'
    return render_template(template_name, username=current_user)
|
|
|
|
|
|
# The blueprint already prefixes every rule with /page, so '/page/index' is served
# at /page/page/index. That legacy rule is kept for backward compatibility. The
# '/index' rule (registered first, and therefore preferred by url_for) serves the
# intended /page/index.
@pb.route('/page/index')
@pb.route('/index')
def index():
    """Home page route: render index.html with the hot-word list.

    If fetching the hot words or rendering fails, error.html is rendered with a
    generic message instead.
    """
    try:
        hotWordList = getAllHotWords()
        logging.info("成功获取热词列表")
        return render_template('index.html', hotWordList=hotWordList)
    except Exception as e:
        logging.error(f"渲染首页时发生错误: {e}")
        return render_template('error.html', error_message="加载首页失败")
|
|
|
|
# With the blueprint's /page prefix, '/page/article/<type>' is served at
# /page/page/article/<type>. That legacy rule is kept for backward compatibility.
# '/article/<type>' (registered first, and therefore preferred by url_for) serves
# the intended /page/article/<type>.
@pb.route('/page/article/<type>')
@pb.route('/article/<type>')
def article(type):
    """Article list route: render article.html with the articles of one type.

    Args:
        type: the article type taken from the URL (the name must match the URL
            variable, so it intentionally shadows the builtin).

    If fetching fails, error.html is rendered with a generic message instead.
    """
    try:
        articleList = getArticleByType(type)
        logging.info(f"成功获取类型为 {type} 的文章列表")
        return render_template('article.html', articleList=articleList)
    except Exception as e:
        logging.error(f"获取文章列表时发生错误: {e}")
        return render_template('error.html', error_message="加载文章列表失败")
|
|
|
|
# This module defines another view function named ``articleChar`` (the /articleChar
# chart page). Without an explicit endpoint, both would claim 'page.articleChar',
# and Flask raises AssertionError when the blueprint is registered, so this view
# gets its own endpoint name.
# With the blueprint's /page prefix, '/page/articleChar/<id>' is served at
# /page/page/articleChar/<id>. That legacy rule is kept. '/articleChar/<id>'
# (registered first, and therefore preferred by url_for) serves the intended
# /page/articleChar/<id>.
@pb.route('/page/articleChar/<id>', endpoint='articleDetail')
@pb.route('/articleChar/<id>', endpoint='articleDetail')
def articleChar(id):
    """Article detail route: render articleChar.html for the article with ``id``.

    Args:
        id: the article id taken from the URL (the name must match the URL
            variable, so it intentionally shadows the builtin).

    Renders error.html when getArticleById() returns a falsy value or raises.
    """
    try:
        article = getArticleById(id)
        if not article:
            logging.warning(f"未找到ID为 {id} 的文章")
            return render_template('error.html', error_message="文章不存在")
        logging.info(f"成功获取ID为 {id} 的文章详情")
        return render_template('articleChar.html', article=article)
    except Exception as e:
        logging.error(f"获取文章详情时发生错误: {e}")
        return render_template('error.html', error_message="加载文章详情失败")
|