Added a base model class and training scripts for various sentiment analysis models, including Naive Bayes, SVM, XGBoost, LSTM, and BERT. Also, improved prediction functionality and the model loading mechanism.
This commit is contained in:
@@ -1,306 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 加载数据集"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from utils import load_corpus, stopwords\n",
|
||||
"\n",
|
||||
"TRAIN_PATH = \"./data/weibo2018/train.txt\"\n",
|
||||
"TEST_PATH = \"./data/weibo2018/test.txt\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Building prefix dict from the default dictionary ...\n",
|
||||
"Loading model from cache /var/folders/rt/khjltk4j6n78x9x3f20hdr6m0000gp/T/jieba.cache\n",
|
||||
"Loading model cost 1.023 seconds.\n",
|
||||
"Prefix dict has been built successfully.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 分别加载训练集和测试集\n",
|
||||
"train_data = load_corpus(TRAIN_PATH)\n",
|
||||
"test_data = load_corpus(TEST_PATH)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>words</th>\n",
|
||||
" <th>label</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>书中 自有 黄金屋 书中 自有 颜如玉 沿着 岁月 的 长河 跋涉 或是 风光旖旎 或是 姹...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>这是 英超 被 黑 的 最惨 的 一次 二哈 二哈 十几年来 中国 只有 孙继海 董方卓 郑...</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>中国 远洋 海运 集团 副总经理 俞曾 港 月 日 在 上 表示 中央 企业 走 出去 是 ...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>看 流星花园 其实 也 还好 啦 现在 的 观念 以及 时尚 眼光 都 不一样 了 或许 十...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>汉武帝 的 罪己 诏 的 真实性 尽管 存在 着 争议 然而 轮台 罪己 诏 作为 中国 历...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" words label\n",
|
||||
"0 书中 自有 黄金屋 书中 自有 颜如玉 沿着 岁月 的 长河 跋涉 或是 风光旖旎 或是 姹... 1\n",
|
||||
"1 这是 英超 被 黑 的 最惨 的 一次 二哈 二哈 十几年来 中国 只有 孙继海 董方卓 郑... 0\n",
|
||||
"2 中国 远洋 海运 集团 副总经理 俞曾 港 月 日 在 上 表示 中央 企业 走 出去 是 ... 1\n",
|
||||
"3 看 流星花园 其实 也 还好 啦 现在 的 观念 以及 时尚 眼光 都 不一样 了 或许 十... 1\n",
|
||||
"4 汉武帝 的 罪己 诏 的 真实性 尽管 存在 着 争议 然而 轮台 罪己 诏 作为 中国 历... 1"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"df_train = pd.DataFrame(train_data, columns=[\"words\", \"label\"])\n",
|
||||
"df_test = pd.DataFrame(test_data, columns=[\"words\", \"label\"])\n",
|
||||
"df_train.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 特征编码(词袋模型)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 61,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/albertdxq/opt/anaconda3/lib/python3.8/site-packages/sklearn/feature_extraction/text.py:383: UserWarning: Your stop_words may be inconsistent with your preprocessing. Tokenizing the stop words generated tokens ['元', '吨', '数', '末'] not in stop_words.\n",
|
||||
" warnings.warn('Your stop_words may be inconsistent with '\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from sklearn.feature_extraction.text import CountVectorizer\n",
|
||||
"\n",
|
||||
"vectorizer = CountVectorizer(token_pattern='\\[?\\w+\\]?', \n",
|
||||
" stop_words=stopwords)\n",
|
||||
"X_train = vectorizer.fit_transform(df_train[\"words\"])\n",
|
||||
"y_train = df_train[\"label\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 62,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_test = vectorizer.transform(df_test[\"words\"])\n",
|
||||
"y_test = df_test[\"label\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 训练模型&测试"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 63,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"MultinomialNB()"
|
||||
]
|
||||
},
|
||||
"execution_count": 63,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from sklearn.naive_bayes import MultinomialNB\n",
|
||||
"\n",
|
||||
"clf = MultinomialNB()\n",
|
||||
"clf.fit(X_train, y_train)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 64,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 在测试集上用模型预测结果\n",
|
||||
"y_pred = clf.predict(X_test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 65,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0 0.75 0.80 0.78 155\n",
|
||||
" 1 0.91 0.88 0.89 345\n",
|
||||
"\n",
|
||||
" accuracy 0.86 500\n",
|
||||
" macro avg 0.83 0.84 0.83 500\n",
|
||||
"weighted avg 0.86 0.86 0.86 500\n",
|
||||
"\n",
|
||||
"准确率: 0.856\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 测试集效果检验\n",
|
||||
"from sklearn import metrics\n",
|
||||
"\n",
|
||||
"print(metrics.classification_report(y_test, y_pred))\n",
|
||||
"print(\"准确率:\", metrics.accuracy_score(y_test, y_pred))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 手动输入句子,判断情感倾向"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 66,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from utils import processing\n",
|
||||
"\n",
|
||||
"strs = [\"终于收获一个最好消息\", \"哭了, 今天怎么这么倒霉\"]\n",
|
||||
"words = [processing(s) for s in strs]\n",
|
||||
"vec = vectorizer.transform(words)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 67,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([1, 0])"
|
||||
]
|
||||
},
|
||||
"execution_count": 67,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"output = clf.predict(vec)\n",
|
||||
"output"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
@@ -1,286 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 加载数据集"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from utils import load_corpus, stopwords\n",
|
||||
"\n",
|
||||
"TRAIN_PATH = \"./data/weibo2018/train.txt\"\n",
|
||||
"TEST_PATH = \"./data/weibo2018/test.txt\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 分别加载训练集和测试集\n",
|
||||
"train_data = load_corpus(TRAIN_PATH)\n",
|
||||
"test_data = load_corpus(TEST_PATH)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>words</th>\n",
|
||||
" <th>label</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>书中 自有 黄金屋 书中 自有 颜如玉 沿着 岁月 的 长河 跋涉 或是 风光旖旎 或是 姹...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>这是 英超 被 黑 的 最惨 的 一次 二哈 二哈 十几年来 中国 只有 孙继海 董方卓 郑...</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>中国 远洋 海运 集团 副总经理 俞曾 港 月 日 在 上 表示 中央 企业 走 出去 是 ...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>看 流星花园 其实 也 还好 啦 现在 的 观念 以及 时尚 眼光 都 不一样 了 或许 十...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>汉武帝 的 罪己 诏 的 真实性 尽管 存在 着 争议 然而 轮台 罪己 诏 作为 中国 历...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" words label\n",
|
||||
"0 书中 自有 黄金屋 书中 自有 颜如玉 沿着 岁月 的 长河 跋涉 或是 风光旖旎 或是 姹... 1\n",
|
||||
"1 这是 英超 被 黑 的 最惨 的 一次 二哈 二哈 十几年来 中国 只有 孙继海 董方卓 郑... 0\n",
|
||||
"2 中国 远洋 海运 集团 副总经理 俞曾 港 月 日 在 上 表示 中央 企业 走 出去 是 ... 1\n",
|
||||
"3 看 流星花园 其实 也 还好 啦 现在 的 观念 以及 时尚 眼光 都 不一样 了 或许 十... 1\n",
|
||||
"4 汉武帝 的 罪己 诏 的 真实性 尽管 存在 着 争议 然而 轮台 罪己 诏 作为 中国 历... 1"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"df_train = pd.DataFrame(train_data, columns=[\"words\", \"label\"])\n",
|
||||
"df_test = pd.DataFrame(test_data, columns=[\"words\", \"label\"])\n",
|
||||
"df_train.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 特征编码(Tf-Idf模型)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 56,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
|
||||
"\n",
|
||||
"vectorizer = TfidfVectorizer(token_pattern='\\[?\\w+\\]?', \n",
|
||||
" stop_words=stopwords)\n",
|
||||
"X_train = vectorizer.fit_transform(df_train[\"words\"])\n",
|
||||
"y_train = df_train[\"label\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 57,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_test = vectorizer.transform(df_test[\"words\"])\n",
|
||||
"y_test = df_test[\"label\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 训练模型&测试"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 58,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"SVC()"
|
||||
]
|
||||
},
|
||||
"execution_count": 58,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from sklearn import svm\n",
|
||||
"\n",
|
||||
"clf = svm.SVC()\n",
|
||||
"clf.fit(X_train, y_train)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 59,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 在测试集上用模型预测结果\n",
|
||||
"y_pred = clf.predict(X_test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 60,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0 0.82 0.69 0.75 155\n",
|
||||
" 1 0.87 0.93 0.90 345\n",
|
||||
"\n",
|
||||
" accuracy 0.86 500\n",
|
||||
" macro avg 0.84 0.81 0.82 500\n",
|
||||
"weighted avg 0.85 0.86 0.85 500\n",
|
||||
"\n",
|
||||
"准确率: 0.856\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 测试集效果检验\n",
|
||||
"from sklearn import metrics\n",
|
||||
"\n",
|
||||
"print(metrics.classification_report(y_test, y_pred))\n",
|
||||
"print(\"准确率:\", metrics.accuracy_score(y_test, y_pred))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 手动输入句子,判断情感倾向"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 49,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from utils import processing\n",
|
||||
"\n",
|
||||
"strs = [\"只要流过的汗与泪都能化作往后的明亮,就值得你为自己喝彩\", \"烦死了!为什么周末还要加班[愤怒]\"]\n",
|
||||
"words = [processing(s) for s in strs]\n",
|
||||
"vec = vectorizer.transform(words)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 50,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([0, 0])"
|
||||
]
|
||||
},
|
||||
"execution_count": 50,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"output = clf.predict(vec)\n",
|
||||
"output"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
@@ -1,314 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 加载数据集"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from utils import load_corpus, stopwords\n",
|
||||
"\n",
|
||||
"TRAIN_PATH = \"./data/weibo2018/train.txt\"\n",
|
||||
"TEST_PATH = \"./data/weibo2018/test.txt\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Building prefix dict from the default dictionary ...\n",
|
||||
"Dumping model to file cache /var/folders/rt/khjltk4j6n78x9x3f20hdr6m0000gp/T/jieba.cache\n",
|
||||
"Loading model cost 1.013 seconds.\n",
|
||||
"Prefix dict has been built successfully.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 分别加载训练集和测试集\n",
|
||||
"train_data = load_corpus(TRAIN_PATH)\n",
|
||||
"test_data = load_corpus(TEST_PATH)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>words</th>\n",
|
||||
" <th>label</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>书中 自有 黄金屋 书中 自有 颜如玉 沿着 岁月 的 长河 跋涉 或是 风光旖旎 或是 姹...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>这是 英超 被 黑 的 最惨 的 一次 二哈 二哈 十几年来 中国 只有 孙继海 董方卓 郑...</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>中国 远洋 海运 集团 副总经理 俞曾 港 月 日 在 上 表示 中央 企业 走 出去 是 ...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>看 流星花园 其实 也 还好 啦 现在 的 观念 以及 时尚 眼光 都 不一样 了 或许 十...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>汉武帝 的 罪己 诏 的 真实性 尽管 存在 着 争议 然而 轮台 罪己 诏 作为 中国 历...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" words label\n",
|
||||
"0 书中 自有 黄金屋 书中 自有 颜如玉 沿着 岁月 的 长河 跋涉 或是 风光旖旎 或是 姹... 1\n",
|
||||
"1 这是 英超 被 黑 的 最惨 的 一次 二哈 二哈 十几年来 中国 只有 孙继海 董方卓 郑... 0\n",
|
||||
"2 中国 远洋 海运 集团 副总经理 俞曾 港 月 日 在 上 表示 中央 企业 走 出去 是 ... 1\n",
|
||||
"3 看 流星花园 其实 也 还好 啦 现在 的 观念 以及 时尚 眼光 都 不一样 了 或许 十... 1\n",
|
||||
"4 汉武帝 的 罪己 诏 的 真实性 尽管 存在 着 争议 然而 轮台 罪己 诏 作为 中国 历... 1"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"df_train = pd.DataFrame(train_data, columns=[\"words\", \"label\"])\n",
|
||||
"df_test = pd.DataFrame(test_data, columns=[\"words\", \"label\"])\n",
|
||||
"df_train.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 特征编码"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/albertdxq/opt/anaconda3/lib/python3.8/site-packages/sklearn/feature_extraction/text.py:383: UserWarning: Your stop_words may be inconsistent with your preprocessing. Tokenizing the stop words generated tokens ['元', '吨', '数', '末'] not in stop_words.\n",
|
||||
" warnings.warn('Your stop_words may be inconsistent with '\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from sklearn.feature_extraction.text import CountVectorizer\n",
|
||||
"\n",
|
||||
"vectorizer = CountVectorizer(token_pattern='\\[?\\w+\\]?', \n",
|
||||
" stop_words=stopwords,\n",
|
||||
" max_features=2000)\n",
|
||||
"X_train = vectorizer.fit_transform(df_train[\"words\"])\n",
|
||||
"y_train = df_train[\"label\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_test = vectorizer.transform(df_test[\"words\"])\n",
|
||||
"y_test = df_test[\"label\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 训练模型&测试"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import xgboost as xgb\n",
|
||||
"\n",
|
||||
"param = {\n",
|
||||
" 'booster':'gbtree',\n",
|
||||
" 'max_depth': 6, \n",
|
||||
" 'scale_pos_weight': 0.5,\n",
|
||||
" 'colsample_bytree': 0.8,\n",
|
||||
" 'objective': 'binary:logistic',\n",
|
||||
" 'eval_metric': 'error',\n",
|
||||
" 'eta': 0.3,\n",
|
||||
" 'nthread': 10,\n",
|
||||
"}\n",
|
||||
"dmatrix = xgb.DMatrix(X_train, label=y_train)\n",
|
||||
"model = xgb.train(param, dmatrix, num_boost_round=200)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 在测试集上用模型预测结果\n",
|
||||
"dmatrix = xgb.DMatrix(X_test)\n",
|
||||
"y_pred = model.predict(dmatrix)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0 0.75 0.82 0.78 155\n",
|
||||
" 1 0.92 0.88 0.90 345\n",
|
||||
"\n",
|
||||
" accuracy 0.86 500\n",
|
||||
" macro avg 0.83 0.85 0.84 500\n",
|
||||
"weighted avg 0.86 0.86 0.86 500\n",
|
||||
"\n",
|
||||
"准确率: 0.86\n",
|
||||
"AUC: 0.9040205703599813\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 测试集效果检验\n",
|
||||
"from sklearn import metrics\n",
|
||||
"\n",
|
||||
"auc_score = metrics.roc_auc_score(y_test, y_pred) # 先计算AUC\n",
|
||||
"y_pred = list(map(lambda x:1 if x > 0.5 else 0, y_pred)) # 二值化\n",
|
||||
"print(metrics.classification_report(y_test, y_pred))\n",
|
||||
"print(\"准确率:\", metrics.accuracy_score(y_test, y_pred))\n",
|
||||
"print(\"AUC:\", auc_score)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 手动输入句子,判断情感倾向(1正/0负)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from utils import processing\n",
|
||||
"\n",
|
||||
"strs = [\"哈哈哈哈哈笑死我了\", \"我也是有脾气的!\"]\n",
|
||||
"words = [processing(s) for s in strs]\n",
|
||||
"vec = vectorizer.transform(words)\n",
|
||||
"dmatrix = xgb.DMatrix(vec)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([0.8683682, 0.3285784], dtype=float32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"output = model.predict(dmatrix)\n",
|
||||
"output"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
@@ -1,717 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"cell_id": 40
|
||||
},
|
||||
"source": [
|
||||
"### 加载数据集"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"cell_id": 1
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from utils import load_corpus, stopwords\n",
|
||||
"\n",
|
||||
"TRAIN_PATH = \"./data/weibo2018/train.txt\"\n",
|
||||
"TEST_PATH = \"./data/weibo2018/test.txt\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"cell_id": 2
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Building prefix dict from the default dictionary ...\n",
|
||||
"Loading model from cache /tmp/jieba.cache\n",
|
||||
"Loading model cost 0.826 seconds.\n",
|
||||
"Prefix dict has been built successfully.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 分别加载训练集和测试集\n",
|
||||
"train_data = load_corpus(TRAIN_PATH)\n",
|
||||
"test_data = load_corpus(TEST_PATH)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"cell_id": 3
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>text</th>\n",
|
||||
" <th>label</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>书中 自有 黄金屋 书中 自有 颜如玉 沿着 岁月 的 长河 跋涉 或是 风光旖旎 或是 姹...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>这是 英超 被 黑 的 最惨 的 一次 二哈 二哈 十几年来 中国 只有 孙继海 董方卓 郑...</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>中国 远洋 海运 集团 副总经理 俞曾 港 月 日 在 上 表示 中央 企业 走 出去 是 ...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>看 流星花园 其实 也 还好 啦 现在 的 观念 以及 时尚 眼光 都 不一样 了 或许 十...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>汉武帝 的 罪己 诏 的 真实性 尽管 存在 着 争议 然而 轮台 罪己 诏 作为 中国 历...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" text label\n",
|
||||
"0 书中 自有 黄金屋 书中 自有 颜如玉 沿着 岁月 的 长河 跋涉 或是 风光旖旎 或是 姹... 1\n",
|
||||
"1 这是 英超 被 黑 的 最惨 的 一次 二哈 二哈 十几年来 中国 只有 孙继海 董方卓 郑... 0\n",
|
||||
"2 中国 远洋 海运 集团 副总经理 俞曾 港 月 日 在 上 表示 中央 企业 走 出去 是 ... 1\n",
|
||||
"3 看 流星花园 其实 也 还好 啦 现在 的 观念 以及 时尚 眼光 都 不一样 了 或许 十... 1\n",
|
||||
"4 汉武帝 的 罪己 诏 的 真实性 尽管 存在 着 争议 然而 轮台 罪己 诏 作为 中国 历... 1"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"df_train = pd.DataFrame(train_data, columns=[\"text\", \"label\"])\n",
|
||||
"df_test = pd.DataFrame(test_data, columns=[\"text\", \"label\"])\n",
|
||||
"df_train.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"cell_id": 42
|
||||
},
|
||||
"source": [
|
||||
"### 训练词向量"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"cell_id": 44
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0 [书中, 自有, 黄金屋, 书中, 自有, 颜如玉, 沿着, 岁月, 的, 长河, 跋涉, ...\n",
|
||||
"1 [这是, 英超, 被, 黑, 的, 最惨, 的, 一次, 二哈, 二哈, 十几年来, 中国,...\n",
|
||||
"2 [中国, 远洋, 海运, 集团, 副总经理, 俞曾, 港, 月, 日, 在, 上, 表示, ...\n",
|
||||
"3 [看, 流星花园, 其实, 也, 还好, 啦, 现在, 的, 观念, 以及, 时尚, 眼光,...\n",
|
||||
"4 [汉武帝, 的, 罪己, 诏, 的, 真实性, 尽管, 存在, 着, 争议, 然而, 轮台,...\n",
|
||||
"Name: text, dtype: object"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# word2vec要求的输入格式: list(word)\n",
|
||||
"wv_input = df_train['text'].map(lambda s: s.split(\" \")) # [for w in s.split(\" \") if w not in stopwords]\n",
|
||||
"wv_input.head() "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"cell_id": 4
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/tiger/.local/lib/python3.7/site-packages/gensim/similarities/__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package <https://pypi.org/project/python-Levenshtein/> is unavailable. Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning.\n",
|
||||
" warnings.warn(msg)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from gensim import models\n",
|
||||
"\n",
|
||||
"# Word2Vec\n",
|
||||
"word2vec = models.Word2Vec(wv_input, \n",
|
||||
" vector_size=64, # 词向量维度\n",
|
||||
" min_count=1, # 最小词频, 因为数据量较小, 这里卡1\n",
|
||||
" epochs=1000) # 迭代轮次"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"cell_id": 46
|
||||
},
|
||||
"source": [
|
||||
"查找近义词, 直观感受训练得到的word2vec效果"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"cell_id": 5,
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[('我', 0.9441561102867126),\n",
|
||||
" ('自己', 0.8928312659263611),\n",
|
||||
" ('他', 0.8796129822731018),\n",
|
||||
" ('的', 0.8601957559585571),\n",
|
||||
" ('她', 0.855070948600769),\n",
|
||||
" ('人', 0.8349815607070923),\n",
|
||||
" ('都', 0.8168802261352539),\n",
|
||||
" ('了', 0.8017680644989014),\n",
|
||||
" ('就', 0.7990766763687134),\n",
|
||||
" ('也', 0.7883183360099792)]"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"word2vec.wv.most_similar(\"你\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"cell_id": 38
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[('哈哈哈', 0.6309624910354614),\n",
|
||||
" ('啦', 0.5457888841629028),\n",
|
||||
" ('可爱', 0.5375339984893799),\n",
|
||||
" ('了', 0.4885959327220917),\n",
|
||||
" ('本柔', 0.46517741680145264),\n",
|
||||
" ('笑', 0.4639575779438019),\n",
|
||||
" ('哈哈哈哈', 0.45851588249206543),\n",
|
||||
" ('心虚', 0.4576280415058136),\n",
|
||||
" ('又', 0.45520466566085815),\n",
|
||||
" ('呀', 0.4494859576225281)]"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"word2vec.wv.most_similar(\"哈哈\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"cell_id": 39
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[('难过', 0.724579393863678),\n",
|
||||
" ('哭', 0.6421604752540588),\n",
|
||||
" ('想', 0.6415957808494568),\n",
|
||||
" ('也', 0.6394745707511902),\n",
|
||||
" ('真的', 0.6263709664344788),\n",
|
||||
" ('我', 0.6136066317558289),\n",
|
||||
" ('都', 0.608888566493988),\n",
|
||||
" ('的', 0.6078368425369263),\n",
|
||||
" ('就', 0.5916700959205627),\n",
|
||||
" ('开心', 0.5899774432182312)]"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"word2vec.wv.most_similar(\"伤心\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"cell_id": 48
|
||||
},
|
||||
"source": [
|
||||
"### 神经网络"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"cell_id": 14
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from torch import nn\n",
|
||||
"from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence,pad_packed_sequence\n",
|
||||
"from torch.utils.data import Dataset, DataLoader\n",
|
||||
"\n",
|
||||
"device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"cell_id": 19
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 超参数\n",
|
||||
"learning_rate = 5e-4\n",
|
||||
"input_size = 768\n",
|
||||
"num_epoches = 5\n",
|
||||
"batch_size = 100\n",
|
||||
"embed_size = 64\n",
|
||||
"hidden_size = 64\n",
|
||||
"num_layers = 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"cell_id": 7
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 数据集\n",
|
||||
"class MyDataset(Dataset):\n",
|
||||
" def __init__(self, df):\n",
|
||||
" self.data = []\n",
|
||||
" self.label = df[\"label\"].tolist()\n",
|
||||
" for s in df[\"text\"].tolist():\n",
|
||||
" vectors = []\n",
|
||||
" for w in s.split(\" \"):\n",
|
||||
" if w in word2vec.wv.key_to_index:\n",
|
||||
" vectors.append(word2vec.wv[w]) # 将每个词替换为对应的词向量\n",
|
||||
" vectors = torch.Tensor(vectors)\n",
|
||||
" self.data.append(vectors)\n",
|
||||
" \n",
|
||||
" def __getitem__(self, index):\n",
|
||||
" data = self.data[index]\n",
|
||||
" label = self.label[index]\n",
|
||||
" return data, label\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.label)\n",
|
||||
"\n",
|
||||
"def collate_fn(data):\n",
|
||||
" \"\"\"\n",
|
||||
" :param data: 第0维:data,第1维:label\n",
|
||||
" :return: 序列化的data、记录实际长度的序列、以及label列表\n",
|
||||
" \"\"\"\n",
|
||||
" data.sort(key=lambda x: len(x[0]), reverse=True) # pack_padded_sequence要求要按照序列的长度倒序排列\n",
|
||||
" data_length = [len(sq[0]) for sq in data]\n",
|
||||
" x = [i[0] for i in data]\n",
|
||||
" y = [i[1] for i in data]\n",
|
||||
" data = pad_sequence(x, batch_first=True, padding_value=0) # 用RNN处理变长序列的必要操作\n",
|
||||
" return data, torch.tensor(y, dtype=torch.float32), data_length\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# 训练集\n",
|
||||
"train_data = MyDataset(df_train)\n",
|
||||
"train_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)\n",
|
||||
"\n",
|
||||
"# 测试集\n",
|
||||
"test_data = MyDataset(df_test)\n",
|
||||
"test_loader = DataLoader(test_data, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {
|
||||
"cell_id": 11
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 网络结构\n",
|
||||
"class LSTM(nn.Module):\n",
|
||||
" def __init__(self, input_size, hidden_size, num_layers):\n",
|
||||
" super(LSTM, self).__init__()\n",
|
||||
" self.hidden_size = hidden_size\n",
|
||||
" self.num_layers = num_layers\n",
|
||||
" self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)\n",
|
||||
" self.fc = nn.Linear(hidden_size * 2, 1) # 双向, 输出维度要*2\n",
|
||||
" self.sigmoid = nn.Sigmoid()\n",
|
||||
"\n",
|
||||
" def forward(self, x, lengths):\n",
|
||||
" h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device) # 双向, 第一个维度要*2\n",
|
||||
" c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)\n",
|
||||
" \n",
|
||||
" packed_input = torch.nn.utils.rnn.pack_padded_sequence(input=x, lengths=lengths, batch_first=True)\n",
|
||||
" packed_out, (h_n, h_c) = self.lstm(packed_input, (h0, c0))\n",
|
||||
"\n",
|
||||
" lstm_out = torch.cat([h_n[-2], h_n[-1]], 1) # 双向, 所以要将最后两维拼接, 得到的就是最后一个time step的输出\n",
|
||||
" out = self.fc(lstm_out)\n",
|
||||
" out = self.sigmoid(out)\n",
|
||||
" return out\n",
|
||||
"\n",
|
||||
"lstm = LSTM(embed_size, hidden_size, num_layers)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {
|
||||
"cell_id": 26
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn import metrics\n",
|
||||
"\n",
|
||||
"# 在测试集效果检验\n",
|
||||
"def test():\n",
|
||||
" y_pred, y_true = [], []\n",
|
||||
"\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for x, labels, lengths in test_loader:\n",
|
||||
" x = x.to(device)\n",
|
||||
" outputs = lstm(x, lengths) # 前向传播\n",
|
||||
" outputs = outputs.view(-1) # 将输出展平\n",
|
||||
" y_pred.append(outputs)\n",
|
||||
" y_true.append(labels)\n",
|
||||
"\n",
|
||||
" y_prob = torch.cat(y_pred)\n",
|
||||
" y_true = torch.cat(y_true)\n",
|
||||
" y_pred = y_prob.clone()\n",
|
||||
" y_pred[y_pred > 0.5] = 1\n",
|
||||
" y_pred[y_pred <= 0.5] = 0\n",
|
||||
" \n",
|
||||
" print(metrics.classification_report(y_true, y_pred))\n",
|
||||
" print(\"准确率:\", metrics.accuracy_score(y_true, y_pred))\n",
|
||||
" print(\"AUC:\", metrics.roc_auc_score(y_true, y_prob) )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {
|
||||
"cell_id": 32
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 定义损失函数和优化器\n",
|
||||
"criterion = nn.BCELoss()\n",
|
||||
"optimizer = torch.optim.Adam(lstm.parameters(), lr=learning_rate)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {
|
||||
"cell_id": 33,
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"epoch:1, step:10, loss:0.689099133014679\n",
|
||||
"epoch:1, step:20, loss:0.6717442870140076\n",
|
||||
"epoch:1, step:30, loss:0.650161862373352\n",
|
||||
"epoch:1, step:40, loss:0.5935518741607666\n",
|
||||
"epoch:1, step:50, loss:0.4994719922542572\n",
|
||||
"epoch:1, step:60, loss:0.4774974286556244\n",
|
||||
"epoch:1, step:70, loss:0.482360303401947\n",
|
||||
"epoch:1, step:80, loss:0.44858306646347046\n",
|
||||
"epoch:1, step:90, loss:0.4513603746891022\n",
|
||||
"epoch:1, step:100, loss:0.4386572241783142\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0.0 0.75 0.80 0.78 155\n",
|
||||
" 1.0 0.91 0.88 0.89 345\n",
|
||||
"\n",
|
||||
" accuracy 0.86 500\n",
|
||||
" macro avg 0.83 0.84 0.83 500\n",
|
||||
"weighted avg 0.86 0.86 0.86 500\n",
|
||||
"\n",
|
||||
"准确率: 0.856\n",
|
||||
"AUC: 0.9141841982234689\n",
|
||||
"saved model: ./model/lstm_1.model\n",
|
||||
"epoch:2, step:10, loss:0.4317778944969177\n",
|
||||
"epoch:2, step:20, loss:0.41387200355529785\n",
|
||||
"epoch:2, step:30, loss:0.4237545430660248\n",
|
||||
"epoch:2, step:40, loss:0.364933043718338\n",
|
||||
"epoch:2, step:50, loss:0.37595903873443604\n",
|
||||
"epoch:2, step:60, loss:0.4067295491695404\n",
|
||||
"epoch:2, step:70, loss:0.41071224212646484\n",
|
||||
"epoch:2, step:80, loss:0.39134103059768677\n",
|
||||
"epoch:2, step:90, loss:0.37907883524894714\n",
|
||||
"epoch:2, step:100, loss:0.4322803020477295\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0.0 0.80 0.63 0.71 155\n",
|
||||
" 1.0 0.85 0.93 0.89 345\n",
|
||||
"\n",
|
||||
" accuracy 0.84 500\n",
|
||||
" macro avg 0.83 0.78 0.80 500\n",
|
||||
"weighted avg 0.83 0.84 0.83 500\n",
|
||||
"\n",
|
||||
"准确率: 0.838\n",
|
||||
"AUC: 0.9174193548387096\n",
|
||||
"saved model: ./model/lstm_2.model\n",
|
||||
"epoch:3, step:10, loss:0.37696003913879395\n",
|
||||
"epoch:3, step:20, loss:0.36385685205459595\n",
|
||||
"epoch:3, step:30, loss:0.3907310664653778\n",
|
||||
"epoch:3, step:40, loss:0.35576874017715454\n",
|
||||
"epoch:3, step:50, loss:0.36152324080467224\n",
|
||||
"epoch:3, step:60, loss:0.3620041608810425\n",
|
||||
"epoch:3, step:70, loss:0.32647013664245605\n",
|
||||
"epoch:3, step:80, loss:0.38903307914733887\n",
|
||||
"epoch:3, step:90, loss:0.34238141775131226\n",
|
||||
"epoch:3, step:100, loss:0.3952549397945404\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0.0 0.75 0.79 0.77 155\n",
|
||||
" 1.0 0.90 0.88 0.89 345\n",
|
||||
"\n",
|
||||
" accuracy 0.85 500\n",
|
||||
" macro avg 0.83 0.84 0.83 500\n",
|
||||
"weighted avg 0.86 0.85 0.85 500\n",
|
||||
"\n",
|
||||
"准确率: 0.854\n",
|
||||
"AUC: 0.9280411407199626\n",
|
||||
"saved model: ./model/lstm_3.model\n",
|
||||
"epoch:4, step:10, loss:0.34902292490005493\n",
|
||||
"epoch:4, step:20, loss:0.3277026116847992\n",
|
||||
"epoch:4, step:30, loss:0.32119297981262207\n",
|
||||
"epoch:4, step:40, loss:0.34501412510871887\n",
|
||||
"epoch:4, step:50, loss:0.3202686905860901\n",
|
||||
"epoch:4, step:60, loss:0.3599391579627991\n",
|
||||
"epoch:4, step:70, loss:0.2958642542362213\n",
|
||||
"epoch:4, step:80, loss:0.3152882158756256\n",
|
||||
"epoch:4, step:90, loss:0.3151417374610901\n",
|
||||
"epoch:4, step:100, loss:0.3314781188964844\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0.0 0.78 0.81 0.79 155\n",
|
||||
" 1.0 0.91 0.90 0.90 345\n",
|
||||
"\n",
|
||||
" accuracy 0.87 500\n",
|
||||
" macro avg 0.84 0.85 0.85 500\n",
|
||||
"weighted avg 0.87 0.87 0.87 500\n",
|
||||
"\n",
|
||||
"准确率: 0.868\n",
|
||||
"AUC: 0.9314258999532491\n",
|
||||
"saved model: ./model/lstm_4.model\n",
|
||||
"epoch:5, step:10, loss:0.2638005316257477\n",
|
||||
"epoch:5, step:20, loss:0.3028942048549652\n",
|
||||
"epoch:5, step:30, loss:0.2819410562515259\n",
|
||||
"epoch:5, step:40, loss:0.2857419550418854\n",
|
||||
"epoch:5, step:50, loss:0.3177730441093445\n",
|
||||
"epoch:5, step:60, loss:0.3140687346458435\n",
|
||||
"epoch:5, step:70, loss:0.32480892539024353\n",
|
||||
"epoch:5, step:80, loss:0.2964351177215576\n",
|
||||
"epoch:5, step:90, loss:0.27567631006240845\n",
|
||||
"epoch:5, step:100, loss:0.2848973870277405\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0.0 0.83 0.74 0.78 155\n",
|
||||
" 1.0 0.89 0.93 0.91 345\n",
|
||||
"\n",
|
||||
" accuracy 0.87 500\n",
|
||||
" macro avg 0.86 0.83 0.84 500\n",
|
||||
"weighted avg 0.87 0.87 0.87 500\n",
|
||||
"\n",
|
||||
"准确率: 0.87\n",
|
||||
"AUC: 0.9310892940626461\n",
|
||||
"saved model: ./model/lstm_5.model\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 迭代训练\n",
|
||||
"for epoch in range(num_epoches):\n",
|
||||
" total_loss = 0\n",
|
||||
" for i, (x, labels, lengths) in enumerate(train_loader):\n",
|
||||
" x = x.to(device)\n",
|
||||
" labels = labels.to(device)\n",
|
||||
" outputs = lstm(x, lengths) # 前向传播\n",
|
||||
" logits = outputs.view(-1) # 将输出展平\n",
|
||||
" loss = criterion(logits, labels) # loss计算\n",
|
||||
" total_loss += loss\n",
|
||||
" optimizer.zero_grad() # 梯度清零\n",
|
||||
" loss.backward(retain_graph=True) # 反向传播,计算梯度\n",
|
||||
" optimizer.step() # 梯度更新\n",
|
||||
" if (i+1) % 10 == 0:\n",
|
||||
" print(\"epoch:{}, step:{}, loss:{}\".format(epoch+1, i+1, total_loss/10))\n",
|
||||
" total_loss = 0\n",
|
||||
" \n",
|
||||
" # test\n",
|
||||
" test()\n",
|
||||
" \n",
|
||||
" # save model\n",
|
||||
" model_path = \"./model/lstm_{}.model\".format(epoch+1)\n",
|
||||
" torch.save(lstm, model_path)\n",
|
||||
" print(\"saved model: \", model_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"cell_id": 36
|
||||
},
|
||||
"source": [
|
||||
"### 手动输入句子,判断情感倾向(1正/0负)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {
|
||||
"cell_id": 51
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"net = torch.load(\"./model/lstm_5.model\") # 训练过程中的巅峰时刻"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"metadata": {
|
||||
"cell_id": 52
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([0.9657, 0.3921])"
|
||||
]
|
||||
},
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from utils import processing\n",
|
||||
"\n",
|
||||
"strs = [\"我想说我会爱你多一点点\", \"日有所思梦感伤\"]\n",
|
||||
"\n",
|
||||
"data = []\n",
|
||||
"for s in strs:\n",
|
||||
" vectors = []\n",
|
||||
" for w in processing(s).split(\" \"):\n",
|
||||
" if w in word2vec.wv.key_to_index:\n",
|
||||
" vectors.append(word2vec.wv[w]) # 将每个词替换为对应的词向量\n",
|
||||
" vectors = torch.Tensor(vectors)\n",
|
||||
" data.append(vectors)\n",
|
||||
"x, _, lengths = collate_fn(list(zip(data, [-1] * len(strs))))\n",
|
||||
"with torch.no_grad():\n",
|
||||
" x = x.to(device)\n",
|
||||
" outputs = lstm(x, lengths) # 前向传播\n",
|
||||
" outputs = outputs.view(-1) # 将输出展平\n",
|
||||
"outputs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cell_id": 54
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.5"
|
||||
},
|
||||
"max_cell_id": 55
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -1,696 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"cell_id": 39
|
||||
},
|
||||
"source": [
|
||||
"### 加载数据集"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"cell_id": 1
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from utils import load_corpus_bert\n",
|
||||
"\n",
|
||||
"TRAIN_PATH = \"./data/weibo2018/train.txt\"\n",
|
||||
"TEST_PATH = \"./data/weibo2018/test.txt\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"cell_id": 3
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 分别加载训练集和测试集\n",
|
||||
"train_data = load_corpus_bert(TRAIN_PATH)\n",
|
||||
"test_data = load_corpus_bert(TEST_PATH)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"cell_id": 4
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>text</th>\n",
|
||||
" <th>label</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>“书中自有黄金屋,书中自有颜如玉”。沿着岁月的长河跋涉,或是风光旖旎,或是姹紫嫣红,万千...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>这是英超被黑的最惨的一次[二哈][二哈]十几年来,中国只有孙继海,董方卓,郑智,李铁登陆过英...</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>中国远洋海运集团副总经理俞曾港4月21日在 上表示,中央企业“走出去”是要站在更高的平台参...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>看《流星花园》其实也还好啦,现在的观念以及时尚眼光都不一样了,或许十几年之后的人看我们的现在...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>汉武帝的罪己诏的真实性尽管存在着争议,然而“轮台罪己诏”作为中国历史上第一份皇帝自我批评的文...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" text label\n",
|
||||
"0 “书中自有黄金屋,书中自有颜如玉”。沿着岁月的长河跋涉,或是风光旖旎,或是姹紫嫣红,万千... 1\n",
|
||||
"1 这是英超被黑的最惨的一次[二哈][二哈]十几年来,中国只有孙继海,董方卓,郑智,李铁登陆过英... 0\n",
|
||||
"2 中国远洋海运集团副总经理俞曾港4月21日在 上表示,中央企业“走出去”是要站在更高的平台参... 1\n",
|
||||
"3 看《流星花园》其实也还好啦,现在的观念以及时尚眼光都不一样了,或许十几年之后的人看我们的现在... 1\n",
|
||||
"4 汉武帝的罪己诏的真实性尽管存在着争议,然而“轮台罪己诏”作为中国历史上第一份皇帝自我批评的文... 1"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"df_train = pd.DataFrame(train_data, columns=[\"text\", \"label\"])\n",
|
||||
"df_test = pd.DataFrame(test_data, columns=[\"text\", \"label\"])\n",
|
||||
"df_train.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"cell_id": 41
|
||||
},
|
||||
"source": [
|
||||
"### 加载Bert"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"cell_id": 5
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from transformers import BertTokenizer, BertModel\n",
|
||||
"\n",
|
||||
"os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\" # 在我的电脑上不加这一句, bert模型会报错\n",
|
||||
"MODEL_PATH = \"./model/chinese_wwm_pytorch\" # 下载地址见 https://github.com/ymcui/Chinese-BERT-wwm"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"cell_id": 6
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Some weights of the model checkpoint at ./model/chinese_wwm_pytorch were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']\n",
|
||||
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
||||
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 加载\n",
|
||||
"tokenizer = BertTokenizer.from_pretrained(MODEL_PATH) # 分词器\n",
|
||||
"bert = BertModel.from_pretrained(MODEL_PATH) # 模型"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"cell_id": 43
|
||||
},
|
||||
"source": [
|
||||
"### 神经网络"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"cell_id": 7
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from torch import nn\n",
|
||||
"from torch.utils.data import Dataset, DataLoader\n",
|
||||
"\n",
|
||||
"device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"cell_id": 8
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 超参数\n",
|
||||
"learning_rate = 1e-3\n",
|
||||
"input_size = 768\n",
|
||||
"num_epoches = 10\n",
|
||||
"batch_size = 100\n",
|
||||
"decay_rate = 0.9"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"cell_id": 9
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 数据集\n",
|
||||
"class MyDataset(Dataset):\n",
|
||||
" def __init__(self, df):\n",
|
||||
" self.data = df[\"text\"].tolist()\n",
|
||||
" self.label = df[\"label\"].tolist()\n",
|
||||
"\n",
|
||||
" def __getitem__(self, index):\n",
|
||||
" data = self.data[index]\n",
|
||||
" label = self.label[index]\n",
|
||||
" return data, label\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.label)\n",
|
||||
"\n",
|
||||
"# 训练集\n",
|
||||
"train_data = MyDataset(df_train)\n",
|
||||
"train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n",
|
||||
"\n",
|
||||
"# 测试集\n",
|
||||
"test_data = MyDataset(df_test)\n",
|
||||
"test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"cell_id": 10
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 网络结构\n",
|
||||
"class Net(nn.Module):\n",
|
||||
" def __init__(self, input_size):\n",
|
||||
" super(Net, self).__init__()\n",
|
||||
" self.fc = nn.Linear(input_size, 1)\n",
|
||||
" self.sigmoid = nn.Sigmoid()\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" out = self.fc(x)\n",
|
||||
" out = self.sigmoid(out)\n",
|
||||
" return out\n",
|
||||
"\n",
|
||||
"net = Net(input_size).to(device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"cell_id": 34
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn import metrics\n",
|
||||
"\n",
|
||||
"# 测试集效果检验\n",
|
||||
"def test():\n",
|
||||
" y_pred, y_true = [], []\n",
|
||||
"\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for words, labels in test_loader:\n",
|
||||
" tokens = tokenizer(words, padding=True)\n",
|
||||
" input_ids = torch.tensor(tokens[\"input_ids\"]).to(device)\n",
|
||||
" attention_mask = torch.tensor(tokens[\"attention_mask\"]).to(device)\n",
|
||||
" last_hidden_states = bert(input_ids, attention_mask=attention_mask)\n",
|
||||
" bert_output = last_hidden_states[0][:, 0]\n",
|
||||
" outputs = net(bert_output) # 前向传播\n",
|
||||
" outputs = outputs.view(-1) # 将输出展平\n",
|
||||
" y_pred.append(outputs)\n",
|
||||
" y_true.append(labels)\n",
|
||||
"\n",
|
||||
" y_prob = torch.cat(y_pred)\n",
|
||||
" y_true = torch.cat(y_true)\n",
|
||||
" y_pred = y_prob.clone()\n",
|
||||
" y_pred[y_pred > 0.5] = 1\n",
|
||||
" y_pred[y_pred <= 0.5] = 0\n",
|
||||
" \n",
|
||||
" print(metrics.classification_report(y_true, y_pred))\n",
|
||||
" print(\"准确率:\", metrics.accuracy_score(y_true, y_pred))\n",
|
||||
" print(\"AUC:\", metrics.roc_auc_score(y_true, y_prob) )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"cell_id": 11
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 定义损失函数和优化器\n",
|
||||
"criterion = nn.BCELoss()\n",
|
||||
"optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)\n",
|
||||
"scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_rate)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"cell_id": 14,
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"epoch:1, step:10, loss:0.6710587739944458\n",
|
||||
"epoch:1, step:20, loss:0.6176288723945618\n",
|
||||
"epoch:1, step:30, loss:0.578593909740448\n",
|
||||
"epoch:1, step:40, loss:0.5502474308013916\n",
|
||||
"epoch:1, step:50, loss:0.5323082804679871\n",
|
||||
"epoch:1, step:60, loss:0.515110194683075\n",
|
||||
"epoch:1, step:70, loss:0.5127577185630798\n",
|
||||
"epoch:1, step:80, loss:0.48992329835891724\n",
|
||||
"epoch:1, step:90, loss:0.4868148863315582\n",
|
||||
"epoch:1, step:100, loss:0.49194520711898804\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0 0.74 0.74 0.74 155\n",
|
||||
" 1 0.88 0.88 0.88 345\n",
|
||||
"\n",
|
||||
" accuracy 0.84 500\n",
|
||||
" macro avg 0.81 0.81 0.81 500\n",
|
||||
"weighted avg 0.84 0.84 0.84 500\n",
|
||||
"\n",
|
||||
"准确率: 0.84\n",
|
||||
"AUC: 0.9027582982702197\n",
|
||||
"saved model: ./model/bert_dnn_1.model\n",
|
||||
"epoch:2, step:10, loss:0.46188774704933167\n",
|
||||
"epoch:2, step:20, loss:0.4335215985774994\n",
|
||||
"epoch:2, step:30, loss:0.4540901184082031\n",
|
||||
"epoch:2, step:40, loss:0.4392821788787842\n",
|
||||
"epoch:2, step:50, loss:0.47116056084632874\n",
|
||||
"epoch:2, step:60, loss:0.4669877886772156\n",
|
||||
"epoch:2, step:70, loss:0.4401330053806305\n",
|
||||
"epoch:2, step:80, loss:0.4518135190010071\n",
|
||||
"epoch:2, step:90, loss:0.4567466676235199\n",
|
||||
"epoch:2, step:100, loss:0.4663034975528717\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0 0.72 0.83 0.77 155\n",
|
||||
" 1 0.92 0.85 0.88 345\n",
|
||||
"\n",
|
||||
" accuracy 0.85 500\n",
|
||||
" macro avg 0.82 0.84 0.83 500\n",
|
||||
"weighted avg 0.86 0.85 0.85 500\n",
|
||||
"\n",
|
||||
"准确率: 0.846\n",
|
||||
"AUC: 0.9149322113136981\n",
|
||||
"saved model: ./model/bert_dnn_2.model\n",
|
||||
"epoch:3, step:10, loss:0.42892661690711975\n",
|
||||
"epoch:3, step:20, loss:0.4225884974002838\n",
|
||||
"epoch:3, step:30, loss:0.415252685546875\n",
|
||||
"epoch:3, step:40, loss:0.43130287528038025\n",
|
||||
"epoch:3, step:50, loss:0.42938193678855896\n",
|
||||
"epoch:3, step:60, loss:0.4340507388114929\n",
|
||||
"epoch:3, step:70, loss:0.4466826319694519\n",
|
||||
"epoch:3, step:80, loss:0.45244288444519043\n",
|
||||
"epoch:3, step:90, loss:0.41808539628982544\n",
|
||||
"epoch:3, step:100, loss:0.44330015778541565\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0 0.73 0.83 0.77 155\n",
|
||||
" 1 0.92 0.86 0.89 345\n",
|
||||
"\n",
|
||||
" accuracy 0.85 500\n",
|
||||
" macro avg 0.82 0.84 0.83 500\n",
|
||||
"weighted avg 0.86 0.85 0.85 500\n",
|
||||
"\n",
|
||||
"准确率: 0.85\n",
|
||||
"AUC: 0.9206545114539505\n",
|
||||
"saved model: ./model/bert_dnn_3.model\n",
|
||||
"epoch:4, step:10, loss:0.39769938588142395\n",
|
||||
"epoch:4, step:20, loss:0.4465697407722473\n",
|
||||
"epoch:4, step:30, loss:0.4216257929801941\n",
|
||||
"epoch:4, step:40, loss:0.41328248381614685\n",
|
||||
"epoch:4, step:50, loss:0.41364049911499023\n",
|
||||
"epoch:4, step:60, loss:0.4332212507724762\n",
|
||||
"epoch:4, step:70, loss:0.4280005395412445\n",
|
||||
"epoch:4, step:80, loss:0.41606149077415466\n",
|
||||
"epoch:4, step:90, loss:0.43310579657554626\n",
|
||||
"epoch:4, step:100, loss:0.4076871871948242\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0 0.76 0.80 0.78 155\n",
|
||||
" 1 0.91 0.88 0.90 345\n",
|
||||
"\n",
|
||||
" accuracy 0.86 500\n",
|
||||
" macro avg 0.83 0.84 0.84 500\n",
|
||||
"weighted avg 0.86 0.86 0.86 500\n",
|
||||
"\n",
|
||||
"准确率: 0.858\n",
|
||||
"AUC: 0.9222814399251986\n",
|
||||
"saved model: ./model/bert_dnn_4.model\n",
|
||||
"epoch:5, step:10, loss:0.39923620223999023\n",
|
||||
"epoch:5, step:20, loss:0.4110904633998871\n",
|
||||
"epoch:5, step:30, loss:0.4446052610874176\n",
|
||||
"epoch:5, step:40, loss:0.4050986170768738\n",
|
||||
"epoch:5, step:50, loss:0.41362982988357544\n",
|
||||
"epoch:5, step:60, loss:0.3961515724658966\n",
|
||||
"epoch:5, step:80, loss:0.43208274245262146\n",
|
||||
"epoch:5, step:90, loss:0.4123595356941223\n",
|
||||
"epoch:5, step:100, loss:0.4114747643470764\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0 0.75 0.81 0.78 155\n",
|
||||
" 1 0.91 0.88 0.90 345\n",
|
||||
"\n",
|
||||
" accuracy 0.86 500\n",
|
||||
" macro avg 0.83 0.85 0.84 500\n",
|
||||
"weighted avg 0.86 0.86 0.86 500\n",
|
||||
"\n",
|
||||
"准确率: 0.86\n",
|
||||
"AUC: 0.9251238896680692\n",
|
||||
"saved model: ./model/bert_dnn_5.model\n",
|
||||
"epoch:6, step:10, loss:0.4047953188419342\n",
|
||||
"epoch:6, step:20, loss:0.41434162855148315\n",
|
||||
"epoch:6, step:30, loss:0.4052816927433014\n",
|
||||
"epoch:6, step:40, loss:0.3726503849029541\n",
|
||||
"epoch:6, step:50, loss:0.4252064824104309\n",
|
||||
"epoch:6, step:60, loss:0.411870539188385\n",
|
||||
"epoch:6, step:70, loss:0.43613123893737793\n",
|
||||
"epoch:6, step:80, loss:0.4038943350315094\n",
|
||||
"epoch:6, step:90, loss:0.40738430619239807\n",
|
||||
"epoch:6, step:100, loss:0.41697797179222107\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0 0.76 0.82 0.79 155\n",
|
||||
" 1 0.92 0.88 0.90 345\n",
|
||||
"\n",
|
||||
" accuracy 0.86 500\n",
|
||||
" macro avg 0.84 0.85 0.84 500\n",
|
||||
"weighted avg 0.87 0.86 0.87 500\n",
|
||||
"\n",
|
||||
"准确率: 0.864\n",
|
||||
"AUC: 0.9266947171575501\n",
|
||||
"saved model: ./model/bert_dnn_6.model\n",
|
||||
"epoch:7, step:10, loss:0.4255238175392151\n",
|
||||
"epoch:7, step:20, loss:0.3951468765735626\n",
|
||||
"epoch:7, step:30, loss:0.41892367601394653\n",
|
||||
"epoch:7, step:40, loss:0.40587490797042847\n",
|
||||
"epoch:7, step:50, loss:0.3918803036212921\n",
|
||||
"epoch:7, step:60, loss:0.43665409088134766\n",
|
||||
"epoch:7, step:70, loss:0.4085603654384613\n",
|
||||
"epoch:7, step:80, loss:0.3877314627170563\n",
|
||||
"epoch:7, step:90, loss:0.3680875301361084\n",
|
||||
"epoch:7, step:100, loss:0.4211949408054352\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0 0.74 0.85 0.79 155\n",
|
||||
" 1 0.93 0.87 0.90 345\n",
|
||||
"\n",
|
||||
" accuracy 0.86 500\n",
|
||||
" macro avg 0.83 0.86 0.84 500\n",
|
||||
"weighted avg 0.87 0.86 0.86 500\n",
|
||||
"\n",
|
||||
"准确率: 0.86\n",
|
||||
"AUC: 0.9282094436652641\n",
|
||||
"saved model: ./model/bert_dnn_7.model\n",
|
||||
"epoch:8, step:10, loss:0.3657851815223694\n",
|
||||
"epoch:8, step:20, loss:0.3944622576236725\n",
|
||||
"epoch:8, step:30, loss:0.40657711029052734\n",
|
||||
"epoch:8, step:40, loss:0.3935934901237488\n",
|
||||
"epoch:8, step:50, loss:0.4171984791755676\n",
|
||||
"epoch:8, step:60, loss:0.4169773459434509\n",
|
||||
"epoch:8, step:70, loss:0.4021885395050049\n",
|
||||
"epoch:8, step:80, loss:0.4106557369232178\n",
|
||||
"epoch:8, step:100, loss:0.4116268754005432\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0 0.80 0.78 0.79 155\n",
|
||||
" 1 0.90 0.91 0.91 345\n",
|
||||
"\n",
|
||||
" accuracy 0.87 500\n",
|
||||
" macro avg 0.85 0.85 0.85 500\n",
|
||||
"weighted avg 0.87 0.87 0.87 500\n",
|
||||
"\n",
|
||||
"准确率: 0.87\n",
|
||||
"AUC: 0.9288078541374474\n",
|
||||
"saved model: ./model/bert_dnn_8.model\n",
|
||||
"epoch:9, step:10, loss:0.4415532052516937\n",
|
||||
"epoch:9, step:20, loss:0.4093624949455261\n",
|
||||
"epoch:9, step:30, loss:0.3825526833534241\n",
|
||||
"epoch:9, step:40, loss:0.3692132532596588\n",
|
||||
"epoch:9, step:50, loss:0.39409342408180237\n",
|
||||
"epoch:9, step:60, loss:0.40440621972084045\n",
|
||||
"epoch:9, step:70, loss:0.3859332203865051\n",
|
||||
"epoch:9, step:80, loss:0.40987101197242737\n",
|
||||
"epoch:9, step:90, loss:0.4061252176761627\n",
|
||||
"epoch:9, step:100, loss:0.4131951332092285\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0 0.75 0.84 0.79 155\n",
|
||||
" 1 0.92 0.88 0.90 345\n",
|
||||
"\n",
|
||||
" accuracy 0.86 500\n",
|
||||
" macro avg 0.84 0.86 0.85 500\n",
|
||||
"weighted avg 0.87 0.86 0.87 500\n",
|
||||
"\n",
|
||||
"准确率: 0.864\n",
|
||||
"AUC: 0.9293501636278635\n",
|
||||
"saved model: ./model/bert_dnn_9.model\n",
|
||||
"epoch:10, step:10, loss:0.40611615777015686\n",
|
||||
"epoch:10, step:20, loss:0.42403316497802734\n",
|
||||
"epoch:10, step:30, loss:0.3972412943840027\n",
|
||||
"epoch:10, step:40, loss:0.4144269526004791\n",
|
||||
"epoch:10, step:50, loss:0.37967294454574585\n",
|
||||
"epoch:10, step:60, loss:0.3992181420326233\n",
|
||||
"epoch:10, step:70, loss:0.3896545469760895\n",
|
||||
"epoch:10, step:80, loss:0.39779797196388245\n",
|
||||
"epoch:10, step:90, loss:0.38316115736961365\n",
|
||||
"epoch:10, step:100, loss:0.4042983055114746\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0 0.76 0.82 0.79 155\n",
|
||||
" 1 0.92 0.88 0.90 345\n",
|
||||
"\n",
|
||||
" accuracy 0.86 500\n",
|
||||
" macro avg 0.84 0.85 0.84 500\n",
|
||||
"weighted avg 0.87 0.86 0.86 500\n",
|
||||
"\n",
|
||||
"准确率: 0.862\n",
|
||||
"AUC: 0.9303973819541842\n",
|
||||
"saved model: ./model/bert_dnn_10.model\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 迭代训练\n",
|
||||
"for epoch in range(num_epoches):\n",
|
||||
" total_loss = 0\n",
|
||||
" for i, (words, labels) in enumerate(train_loader):\n",
|
||||
" tokens = tokenizer(words, padding=True)\n",
|
||||
" input_ids = torch.tensor(tokens[\"input_ids\"]).to(device)\n",
|
||||
" attention_mask = torch.tensor(tokens[\"attention_mask\"]).to(device)\n",
|
||||
" labels = labels.float().to(device)\n",
|
||||
" with torch.no_grad():\n",
|
||||
" last_hidden_states = bert(input_ids, attention_mask=attention_mask)\n",
|
||||
" bert_output = last_hidden_states[0][:, 0]\n",
|
||||
" optimizer.zero_grad() # 梯度清零\n",
|
||||
" outputs = net(bert_output) # 前向传播\n",
|
||||
" logits = outputs.view(-1) # 将输出展平\n",
|
||||
" loss = criterion(logits, labels) # loss计算\n",
|
||||
" total_loss += loss\n",
|
||||
" loss.backward() # 反向传播,计算梯度\n",
|
||||
" optimizer.step() # 梯度更新\n",
|
||||
" if (i+1) % 10 == 0:\n",
|
||||
" print(\"epoch:{}, step:{}, loss:{}\".format(epoch+1, i+1, total_loss/10))\n",
|
||||
" total_loss = 0\n",
|
||||
" \n",
|
||||
" # learning_rate decay\n",
|
||||
" scheduler.step()\n",
|
||||
" \n",
|
||||
" # test\n",
|
||||
" test()\n",
|
||||
" \n",
|
||||
" # save model\n",
|
||||
" model_path = \"./model/bert_dnn_{}.model\".format(epoch+1)\n",
|
||||
" torch.save(net, model_path)\n",
|
||||
" print(\"saved model: \", model_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"cell_id": 23
|
||||
},
|
||||
"source": [
|
||||
"### 手动输入句子,判断情感倾向(1正/0负)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"cell_id": 38
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"net = torch.load(\"./model/bert_dnn_8.model\") # 训练过程中的巅峰时刻"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"cell_id": 37
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([[0.9007],\n",
|
||||
" [0.2211]], grad_fn=<SigmoidBackward>)"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"s = [\"华丽繁荣的城市、充满回忆的小镇、郁郁葱葱的山谷...\", \"突然就觉得人间不值得\"]\n",
|
||||
"tokens = tokenizer(s, padding=True)\n",
|
||||
"input_ids = torch.tensor(tokens[\"input_ids\"])\n",
|
||||
"attention_mask = torch.tensor(tokens[\"attention_mask\"])\n",
|
||||
"last_hidden_states = bert(input_ids, attention_mask=attention_mask)\n",
|
||||
"bert_output = last_hidden_states[0][:, 0]\n",
|
||||
"outputs = net(bert_output)\n",
|
||||
"outputs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {
|
||||
"cell_id": 27,
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([[0.9735],\n",
|
||||
" [0.9882]], grad_fn=<SigmoidBackward>)"
|
||||
]
|
||||
},
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"s = [\"今天天气真好\", \"今天天气特别特别棒\"]\n",
|
||||
"tokens = tokenizer(s, padding=True)\n",
|
||||
"input_ids = torch.tensor(tokens[\"input_ids\"])\n",
|
||||
"attention_mask = torch.tensor(tokens[\"attention_mask\"])\n",
|
||||
"last_hidden_states = bert(input_ids, attention_mask=attention_mask)\n",
|
||||
"bert_output = last_hidden_states[0][:, 0]\n",
|
||||
"outputs = net(bert_output)\n",
|
||||
"outputs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cell_id": 32
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.5"
|
||||
},
|
||||
"max_cell_id": 45
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -1,32 +1,107 @@
|
||||
# WeiboSentiment
|
||||
用各种机器学习对中文微博进行情感分析
|
||||
语料来源: https://github.com/dengxiuqi/weibo2018
|
||||
---
|
||||
##### "微博情感分析"是我本科的毕业设计, 也是我入门NLP的项目, 就把它发出来供大家交流。
|
||||
##### 2021.06.07更新: 之前的版本写得比较随意, 没想到star破百了, 私下也有一些刚入门NLP的同学因为这个项目联系我, 就更新一下这个项目吧
|
||||
* 重构项目架构和代码, 提高可读性
|
||||
* 每个文件中的特征、数据处理方法与模型细节都尽可能避免重复, 以给各位同学提供更多的参考
|
||||
* 神经网络结构换成了pytorch, 需要`tensorflow 1.0`代码的同学请回退至`445998`版本。
|
||||
* 新增了`Bert`模型
|
||||
* 由于gensim新老版本很多语法不兼容, 将gensim更新为4.0版本
|
||||
----
|
||||
#### 项目说明
|
||||
* 训练集10000条语料, 测试集500条语料
|
||||
* 使用朴素贝叶斯、SVM、XGBoost、LSTM和Bert, 等多种模型搭建并训练二分类模型
|
||||
* 前3个模型都采用端到端的训练方法
|
||||
* LSTM先预训练得到Word2Vec词向量, 在训练神经网络
|
||||
* `Bert`使用的是哈工大的预训练模型, 用Bert的`[CLS]`位输出在一个下游网络上进行finetune。预训练模型需要自行下载:
|
||||
* github下载地址: https://github.com/ymcui/Chinese-BERT-wwm
|
||||
* baidu网盘: https://pan.baidu.com/s/16z-ybrqT6wLdy_mLHtywSw 密码: djkj
|
||||
* 下载后将文件夹放在`./model`文件夹下, 并将`bert_config.json`改名为`config.json`
|
||||
---
|
||||
#### 实验结果
|
||||
各种分类器在测试集上的测试结果
|
||||
# 微博情感分析 - 传统机器学习方法
|
||||
|
||||
|模型|准确率|AUC|
|
||||
| :---: | :---: | :---: |
|
||||
|1.bayes|0.856| - |
|
||||
|2.svm|0.856| - |
|
||||
|3.xgboost|0.86| 0.904 |
|
||||
|4.lstm|0.87| 0.931 |
|
||||
|5.bert|0.87| 0.929 |
|
||||
## 项目介绍
|
||||
|
||||
本项目使用5种传统机器学习方法对中文微博进行情感二分类(正面/负面):
|
||||
|
||||
- **朴素贝叶斯**: 基于词袋模型的概率分类
|
||||
- **SVM**: 基于TF-IDF特征的支持向量机
|
||||
- **XGBoost**: 梯度提升决策树
|
||||
- **LSTM**: 循环神经网络 + Word2Vec词向量
|
||||
- **BERT+分类头**: 预训练语言模型接分类器(我认为也属于传统ML范畴)
|
||||
|
||||
## 模型性能
|
||||
|
||||
在微博情感数据集上的表现(训练集10000条,测试集500条):
|
||||
|
||||
| 模型 | 准确率 | AUC | 特点 |
|
||||
|------|--------|-----|------|
|
||||
| 朴素贝叶斯 | 85.6% | - | 速度快,内存占用小 |
|
||||
| SVM | 85.6% | - | 泛化能力好 |
|
||||
| XGBoost | 86.0% | 90.4% | 性能稳定,支持特征重要性 |
|
||||
| LSTM | 87.0% | 93.1% | 理解序列信息和上下文 |
|
||||
| BERT+分类头 | 87.0% | 92.9% | 强大的语义理解能力 |
|
||||
|
||||
## 环境配置
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
数据文件结构:
|
||||
```
|
||||
data/
|
||||
├── weibo2018/
|
||||
│ ├── train.txt
|
||||
│ └── test.txt
|
||||
└── stopwords.txt
|
||||
```
|
||||
|
||||
## 训练模型(后面可以不接参数直接运行)
|
||||
|
||||
### 朴素贝叶斯
|
||||
```bash
|
||||
python bayes_train.py
|
||||
```
|
||||
|
||||
### SVM
|
||||
```bash
|
||||
python svm_train.py --kernel rbf --C 1.0
|
||||
```
|
||||
|
||||
### XGBoost
|
||||
```bash
|
||||
python xgboost_train.py --max_depth 6 --eta 0.3 --num_boost_round 200
|
||||
```
|
||||
|
||||
### LSTM
|
||||
```bash
|
||||
python lstm_train.py --epochs 5 --batch_size 100 --hidden_size 64
|
||||
```
|
||||
|
||||
### BERT
|
||||
```bash
|
||||
python bert_train.py --epochs 10 --batch_size 100 --learning_rate 1e-3
|
||||
```
|
||||
|
||||
注:BERT模型会自动下载中文预训练模型(bert-base-chinese)
|
||||
|
||||
## 使用预测
|
||||
|
||||
### 交互式预测(推荐)
|
||||
```bash
|
||||
python predict.py
|
||||
```
|
||||
|
||||
### 命令行预测
|
||||
```bash
|
||||
# 单模型预测
|
||||
python predict.py --model_type bert --text "今天天气真好,心情很棒"
|
||||
|
||||
# 多模型集成预测
|
||||
python predict.py --ensemble --text "这部电影太无聊了"
|
||||
```
|
||||
|
||||
## 文件结构
|
||||
|
||||
```
|
||||
WeiboSentiment_MachineLearning/
|
||||
├── bayes_train.py # 朴素贝叶斯训练
|
||||
├── svm_train.py # SVM训练
|
||||
├── xgboost_train.py # XGBoost训练
|
||||
├── lstm_train.py # LSTM训练
|
||||
├── bert_train.py # BERT训练
|
||||
├── predict.py # 统一预测程序
|
||||
├── base_model.py # 基础模型类
|
||||
├── utils.py # 工具函数
|
||||
├── requirements.txt # 依赖包
|
||||
├── model/ # 模型保存目录
|
||||
└── data/ # 数据目录
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **BERT模型**首次运行会自动下载预训练模型(约400MB)
|
||||
2. **LSTM模型**训练时间较长,建议使用GPU
|
||||
3. **模型保存**在 `model/` 目录下,确保有足够磁盘空间
|
||||
4. **内存需求**BERT > LSTM > XGBoost > SVM > 朴素贝叶斯
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
基础模型类,为所有情感分析模型提供统一接口
|
||||
"""
|
||||
import os
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple, Dict, Any
|
||||
import pandas as pd
|
||||
from sklearn.metrics import accuracy_score, f1_score, classification_report
|
||||
from utils import load_corpus
|
||||
|
||||
|
||||
class BaseModel(ABC):
|
||||
"""情感分析模型基类"""
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
self.model_name = model_name
|
||||
self.model = None
|
||||
self.vectorizer = None
|
||||
self.is_trained = False
|
||||
|
||||
@abstractmethod
|
||||
def train(self, train_data: List[Tuple[str, int]], **kwargs) -> None:
|
||||
"""训练模型"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, texts: List[str]) -> List[int]:
|
||||
"""预测文本情感"""
|
||||
pass
|
||||
|
||||
def predict_single(self, text: str) -> Tuple[int, float]:
|
||||
"""预测单条文本的情感
|
||||
|
||||
Args:
|
||||
text: 待预测文本
|
||||
|
||||
Returns:
|
||||
(predicted_label, confidence)
|
||||
"""
|
||||
predictions = self.predict([text])
|
||||
return predictions[0], 0.0 # 默认置信度为0
|
||||
|
||||
def evaluate(self, test_data: List[Tuple[str, int]]) -> Dict[str, float]:
|
||||
"""评估模型性能"""
|
||||
if not self.is_trained:
|
||||
raise ValueError(f"模型 {self.model_name} 尚未训练,请先调用train方法")
|
||||
|
||||
texts = [item[0] for item in test_data]
|
||||
labels = [item[1] for item in test_data]
|
||||
|
||||
predictions = self.predict(texts)
|
||||
|
||||
accuracy = accuracy_score(labels, predictions)
|
||||
f1 = f1_score(labels, predictions, average='weighted')
|
||||
|
||||
print(f"\n{self.model_name} 模型评估结果:")
|
||||
print(f"准确率: {accuracy:.4f}")
|
||||
print(f"F1分数: {f1:.4f}")
|
||||
print("\n详细报告:")
|
||||
print(classification_report(labels, predictions))
|
||||
|
||||
return {
|
||||
'accuracy': accuracy,
|
||||
'f1_score': f1,
|
||||
'classification_report': classification_report(labels, predictions)
|
||||
}
|
||||
|
||||
def save_model(self, model_path: str = None) -> None:
|
||||
"""保存模型到文件"""
|
||||
if not self.is_trained:
|
||||
raise ValueError(f"模型 {self.model_name} 尚未训练,无法保存")
|
||||
|
||||
if model_path is None:
|
||||
model_path = f"model/{self.model_name}_model.pkl"
|
||||
|
||||
# 创建保存目录
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
|
||||
# 保存模型数据
|
||||
model_data = {
|
||||
'model': self.model,
|
||||
'vectorizer': self.vectorizer,
|
||||
'model_name': self.model_name,
|
||||
'is_trained': self.is_trained
|
||||
}
|
||||
|
||||
with open(model_path, 'wb') as f:
|
||||
pickle.dump(model_data, f)
|
||||
|
||||
print(f"模型已保存到: {model_path}")
|
||||
|
||||
def load_model(self, model_path: str) -> None:
|
||||
"""从文件加载模型"""
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
||||
|
||||
with open(model_path, 'rb') as f:
|
||||
model_data = pickle.load(f)
|
||||
|
||||
self.model = model_data['model']
|
||||
self.vectorizer = model_data.get('vectorizer')
|
||||
self.model_name = model_data['model_name']
|
||||
self.is_trained = model_data['is_trained']
|
||||
|
||||
print(f"已加载模型: {model_path}")
|
||||
|
||||
@staticmethod
|
||||
def load_data(train_path: str, test_path: str) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
|
||||
"""加载训练和测试数据"""
|
||||
print("加载训练数据...")
|
||||
train_data = load_corpus(train_path)
|
||||
print(f"训练数据量: {len(train_data)}")
|
||||
|
||||
print("加载测试数据...")
|
||||
test_data = load_corpus(test_path)
|
||||
print(f"测试数据量: {len(test_data)}")
|
||||
|
||||
return train_data, test_data
|
||||
@@ -0,0 +1,155 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
朴素贝叶斯情感分析模型训练脚本
|
||||
"""
|
||||
import argparse
|
||||
import pandas as pd
|
||||
from typing import List, Tuple
|
||||
from sklearn.feature_extraction.text import CountVectorizer
|
||||
from sklearn.naive_bayes import MultinomialNB
|
||||
from sklearn.metrics import accuracy_score, f1_score
|
||||
|
||||
from base_model import BaseModel
|
||||
from utils import stopwords
|
||||
|
||||
|
||||
class BayesModel(BaseModel):
|
||||
"""朴素贝叶斯情感分析模型"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("Bayes")
|
||||
|
||||
def train(self, train_data: List[Tuple[str, int]], **kwargs) -> None:
|
||||
"""训练朴素贝叶斯模型
|
||||
|
||||
Args:
|
||||
train_data: 训练数据,格式为[(text, label), ...]
|
||||
**kwargs: 其他参数
|
||||
"""
|
||||
print(f"开始训练 {self.model_name} 模型...")
|
||||
|
||||
# 准备数据
|
||||
df_train = pd.DataFrame(train_data, columns=["words", "label"])
|
||||
|
||||
# 特征编码(词袋模型)
|
||||
print("构建词袋模型...")
|
||||
self.vectorizer = CountVectorizer(
|
||||
token_pattern=r'\[?\w+\]?',
|
||||
stop_words=stopwords
|
||||
)
|
||||
|
||||
X_train = self.vectorizer.fit_transform(df_train["words"])
|
||||
y_train = df_train["label"]
|
||||
|
||||
print(f"特征维度: {X_train.shape[1]}")
|
||||
|
||||
# 训练模型
|
||||
print("训练朴素贝叶斯分类器...")
|
||||
self.model = MultinomialNB()
|
||||
self.model.fit(X_train, y_train)
|
||||
|
||||
self.is_trained = True
|
||||
print(f"{self.model_name} 模型训练完成!")
|
||||
|
||||
def predict(self, texts: List[str]) -> List[int]:
|
||||
"""预测文本情感
|
||||
|
||||
Args:
|
||||
texts: 待预测文本列表
|
||||
|
||||
Returns:
|
||||
预测结果列表
|
||||
"""
|
||||
if not self.is_trained:
|
||||
raise ValueError(f"模型 {self.model_name} 尚未训练,请先调用train方法")
|
||||
|
||||
# 特征转换
|
||||
X = self.vectorizer.transform(texts)
|
||||
|
||||
# 预测
|
||||
predictions = self.model.predict(X)
|
||||
|
||||
return predictions.tolist()
|
||||
|
||||
def predict_single(self, text: str) -> Tuple[int, float]:
|
||||
"""预测单条文本的情感
|
||||
|
||||
Args:
|
||||
text: 待预测文本
|
||||
|
||||
Returns:
|
||||
(predicted_label, confidence)
|
||||
"""
|
||||
if not self.is_trained:
|
||||
raise ValueError(f"模型 {self.model_name} 尚未训练,请先调用train方法")
|
||||
|
||||
# 特征转换
|
||||
X = self.vectorizer.transform([text])
|
||||
|
||||
# 预测
|
||||
prediction = self.model.predict(X)[0]
|
||||
probabilities = self.model.predict_proba(X)[0]
|
||||
confidence = max(probabilities)
|
||||
|
||||
return int(prediction), float(confidence)
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
parser = argparse.ArgumentParser(description='朴素贝叶斯情感分析模型训练')
|
||||
parser.add_argument('--train_path', type=str, default='./data/weibo2018/train.txt',
|
||||
help='训练数据路径')
|
||||
parser.add_argument('--test_path', type=str, default='./data/weibo2018/test.txt',
|
||||
help='测试数据路径')
|
||||
parser.add_argument('--model_path', type=str, default='./model/bayes_model.pkl',
|
||||
help='模型保存路径')
|
||||
parser.add_argument('--eval_only', action='store_true',
|
||||
help='仅评估已有模型,不进行训练')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 创建模型
|
||||
model = BayesModel()
|
||||
|
||||
if args.eval_only:
|
||||
# 仅评估模式
|
||||
print("评估模式:加载已有模型进行评估")
|
||||
model.load_model(args.model_path)
|
||||
|
||||
# 加载测试数据
|
||||
_, test_data = BaseModel.load_data(args.train_path, args.test_path)
|
||||
|
||||
# 评估模型
|
||||
model.evaluate(test_data)
|
||||
else:
|
||||
# 训练模式
|
||||
# 加载数据
|
||||
train_data, test_data = BaseModel.load_data(args.train_path, args.test_path)
|
||||
|
||||
# 训练模型
|
||||
model.train(train_data)
|
||||
|
||||
# 评估模型
|
||||
model.evaluate(test_data)
|
||||
|
||||
# 保存模型
|
||||
model.save_model(args.model_path)
|
||||
|
||||
# 示例预测
|
||||
print("\n示例预测:")
|
||||
test_texts = [
|
||||
"今天天气真好,心情很棒",
|
||||
"这部电影太无聊了,浪费时间",
|
||||
"哈哈哈,太有趣了"
|
||||
]
|
||||
|
||||
for text in test_texts:
|
||||
pred, conf = model.predict_single(text)
|
||||
sentiment = "正面" if pred == 1 else "负面"
|
||||
print(f"文本: {text}")
|
||||
print(f"预测: {sentiment} (置信度: {conf:.4f})")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,413 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
BERT情感分析模型训练脚本
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from transformers import BertTokenizer, BertModel
|
||||
from sklearn.metrics import accuracy_score, f1_score, classification_report, roc_auc_score
|
||||
from typing import List, Tuple
|
||||
import warnings
|
||||
import requests
|
||||
from pathlib import Path
|
||||
|
||||
from base_model import BaseModel
|
||||
from utils import load_corpus_bert
|
||||
|
||||
# 忽略transformers的警告
|
||||
warnings.filterwarnings("ignore")
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||
|
||||
|
||||
class BertDataset(Dataset):
|
||||
"""BERT数据集"""
|
||||
|
||||
def __init__(self, data: List[Tuple[str, int]]):
|
||||
self.data = [item[0] for item in data]
|
||||
self.labels = [item[1] for item in data]
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.data[index], self.labels[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.labels)
|
||||
|
||||
|
||||
class BertClassifier(nn.Module):
|
||||
"""BERT分类器网络"""
|
||||
|
||||
def __init__(self, input_size):
|
||||
super(BertClassifier, self).__init__()
|
||||
self.fc = nn.Linear(input_size, 1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
out = self.fc(x)
|
||||
out = self.sigmoid(out)
|
||||
return out
|
||||
|
||||
|
||||
class BertModel_Custom(BaseModel):
|
||||
"""BERT情感分析模型"""
|
||||
|
||||
def __init__(self, model_path: str = "./model/chinese_wwm_pytorch"):
|
||||
super().__init__("BERT")
|
||||
self.model_path = model_path
|
||||
self.tokenizer = None
|
||||
self.bert = None
|
||||
self.classifier = None
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
def _download_bert_model(self):
|
||||
"""自动下载BERT预训练模型"""
|
||||
print(f"BERT模型不存在,正在下载中文BERT预训练模型...")
|
||||
print("下载来源: bert-base-chinese (Hugging Face)")
|
||||
|
||||
try:
|
||||
# 创建模型目录
|
||||
os.makedirs(self.model_path, exist_ok=True)
|
||||
|
||||
# 使用Hugging Face的中文BERT模型
|
||||
model_name = "bert-base-chinese"
|
||||
print(f"正在从Hugging Face下载 {model_name}...")
|
||||
|
||||
# 下载tokenizer
|
||||
print("下载分词器...")
|
||||
tokenizer = BertTokenizer.from_pretrained(model_name)
|
||||
tokenizer.save_pretrained(self.model_path)
|
||||
|
||||
# 下载模型
|
||||
print("下载BERT模型...")
|
||||
bert_model = BertModel.from_pretrained(model_name)
|
||||
bert_model.save_pretrained(self.model_path)
|
||||
|
||||
print(f"✅ BERT模型下载完成,保存在: {self.model_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ BERT模型下载失败: {e}")
|
||||
print("\n💡 您可以手动下载BERT模型:")
|
||||
print("1. 访问 https://huggingface.co/bert-base-chinese")
|
||||
print("2. 或使用哈工大中文BERT: https://github.com/ymcui/Chinese-BERT-wwm")
|
||||
print(f"3. 将模型文件解压到: {self.model_path}")
|
||||
return False
|
||||
|
||||
def _load_bert(self):
|
||||
"""加载BERT模型和分词器"""
|
||||
print(f"加载BERT模型: {self.model_path}")
|
||||
|
||||
# 如果模型不存在,尝试自动下载
|
||||
if not os.path.exists(self.model_path) or not any(os.scandir(self.model_path)):
|
||||
print("BERT模型不存在,尝试自动下载...")
|
||||
if not self._download_bert_model():
|
||||
raise FileNotFoundError(f"BERT模型下载失败,请手动下载到: {self.model_path}")
|
||||
|
||||
try:
|
||||
self.tokenizer = BertTokenizer.from_pretrained(self.model_path)
|
||||
self.bert = BertModel.from_pretrained(self.model_path).to(self.device)
|
||||
|
||||
# 冻结BERT参数
|
||||
for param in self.bert.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
print("✅ BERT模型加载完成")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ BERT模型加载失败: {e}")
|
||||
print("尝试使用在线模型...")
|
||||
|
||||
# 如果本地加载失败,尝试直接使用在线模型
|
||||
try:
|
||||
model_name = "bert-base-chinese"
|
||||
self.tokenizer = BertTokenizer.from_pretrained(model_name)
|
||||
self.bert = BertModel.from_pretrained(model_name).to(self.device)
|
||||
|
||||
# 冻结BERT参数
|
||||
for param in self.bert.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
print("✅ 在线BERT模型加载完成")
|
||||
|
||||
except Exception as e2:
|
||||
print(f"❌ 在线模型也加载失败: {e2}")
|
||||
raise FileNotFoundError(f"无法加载BERT模型,请检查网络连接或手动下载模型到: {self.model_path}")
|
||||
|
||||
def train(self, train_data: List[Tuple[str, int]], **kwargs) -> None:
|
||||
"""训练BERT模型"""
|
||||
print(f"开始训练 {self.model_name} 模型...")
|
||||
|
||||
# 加载BERT
|
||||
self._load_bert()
|
||||
|
||||
# 超参数
|
||||
learning_rate = kwargs.get('learning_rate', 1e-3)
|
||||
num_epochs = kwargs.get('num_epochs', 10)
|
||||
batch_size = kwargs.get('batch_size', 100)
|
||||
input_size = kwargs.get('input_size', 768)
|
||||
decay_rate = kwargs.get('decay_rate', 0.9)
|
||||
|
||||
print(f"BERT超参数: lr={learning_rate}, epochs={num_epochs}, "
|
||||
f"batch_size={batch_size}, input_size={input_size}")
|
||||
|
||||
# 创建数据集
|
||||
train_dataset = BertDataset(train_data)
|
||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
||||
|
||||
# 创建分类器
|
||||
self.classifier = BertClassifier(input_size).to(self.device)
|
||||
|
||||
# 损失函数和优化器
|
||||
criterion = nn.BCELoss()
|
||||
optimizer = torch.optim.Adam(self.classifier.parameters(), lr=learning_rate)
|
||||
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_rate)
|
||||
|
||||
# 训练循环
|
||||
self.bert.eval() # BERT始终保持评估模式
|
||||
self.classifier.train()
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
total_loss = 0
|
||||
num_batches = 0
|
||||
|
||||
for i, (words, labels) in enumerate(train_loader):
|
||||
# 分词和编码
|
||||
tokens = self.tokenizer(words, padding=True, truncation=True,
|
||||
max_length=512, return_tensors='pt')
|
||||
input_ids = tokens["input_ids"].to(self.device)
|
||||
attention_mask = tokens["attention_mask"].to(self.device)
|
||||
labels = torch.tensor(labels, dtype=torch.float32).to(self.device)
|
||||
|
||||
# 获取BERT输出(冻结参数)
|
||||
with torch.no_grad():
|
||||
bert_outputs = self.bert(input_ids, attention_mask=attention_mask)
|
||||
bert_output = bert_outputs[0][:, 0] # [CLS] token的输出
|
||||
|
||||
# 分类器前向传播
|
||||
optimizer.zero_grad()
|
||||
outputs = self.classifier(bert_output)
|
||||
logits = outputs.view(-1)
|
||||
loss = criterion(logits, labels)
|
||||
|
||||
# 反向传播
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
num_batches += 1
|
||||
|
||||
if (i + 1) % 10 == 0:
|
||||
avg_loss = total_loss / num_batches
|
||||
print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}], Loss: {avg_loss:.4f}")
|
||||
total_loss = 0
|
||||
num_batches = 0
|
||||
|
||||
# 学习率衰减
|
||||
scheduler.step()
|
||||
|
||||
# 保存每个epoch的模型
|
||||
if kwargs.get('save_each_epoch', False):
|
||||
epoch_model_path = f"./model/bert_epoch_{epoch+1}.pth"
|
||||
os.makedirs(os.path.dirname(epoch_model_path), exist_ok=True)
|
||||
torch.save(self.classifier.state_dict(), epoch_model_path)
|
||||
print(f"已保存模型: {epoch_model_path}")
|
||||
|
||||
self.is_trained = True
|
||||
print(f"{self.model_name} 模型训练完成!")
|
||||
|
||||
def predict(self, texts: List[str]) -> List[int]:
|
||||
"""预测文本情感"""
|
||||
if not self.is_trained:
|
||||
raise ValueError(f"模型 {self.model_name} 尚未训练,请先调用train方法")
|
||||
|
||||
predictions = []
|
||||
batch_size = 32
|
||||
|
||||
self.bert.eval()
|
||||
self.classifier.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch_texts = texts[i:i+batch_size]
|
||||
|
||||
# 分词和编码
|
||||
tokens = self.tokenizer(batch_texts, padding=True, truncation=True,
|
||||
max_length=512, return_tensors='pt')
|
||||
input_ids = tokens["input_ids"].to(self.device)
|
||||
attention_mask = tokens["attention_mask"].to(self.device)
|
||||
|
||||
# 获取BERT输出
|
||||
bert_outputs = self.bert(input_ids, attention_mask=attention_mask)
|
||||
bert_output = bert_outputs[0][:, 0]
|
||||
|
||||
# 分类器预测
|
||||
outputs = self.classifier(bert_output)
|
||||
outputs = outputs.view(-1)
|
||||
|
||||
# 转换为类别标签
|
||||
preds = (outputs > 0.5).cpu().numpy()
|
||||
predictions.extend(preds.astype(int).tolist())
|
||||
|
||||
return predictions
|
||||
|
||||
def predict_single(self, text: str) -> Tuple[int, float]:
|
||||
"""预测单条文本的情感"""
|
||||
if not self.is_trained:
|
||||
raise ValueError(f"模型 {self.model_name} 尚未训练,请先调用train方法")
|
||||
|
||||
self.bert.eval()
|
||||
self.classifier.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
# 分词和编码
|
||||
tokens = self.tokenizer([text], padding=True, truncation=True,
|
||||
max_length=512, return_tensors='pt')
|
||||
input_ids = tokens["input_ids"].to(self.device)
|
||||
attention_mask = tokens["attention_mask"].to(self.device)
|
||||
|
||||
# 获取BERT输出
|
||||
bert_outputs = self.bert(input_ids, attention_mask=attention_mask)
|
||||
bert_output = bert_outputs[0][:, 0]
|
||||
|
||||
# 分类器预测
|
||||
output = self.classifier(bert_output)
|
||||
prob = output.item()
|
||||
|
||||
prediction = int(prob > 0.5)
|
||||
confidence = prob if prediction == 1 else 1 - prob
|
||||
|
||||
return prediction, confidence
|
||||
|
||||
def save_model(self, model_path: str = None) -> None:
|
||||
"""保存模型"""
|
||||
if not self.is_trained:
|
||||
raise ValueError(f"模型 {self.model_name} 尚未训练,无法保存")
|
||||
|
||||
if model_path is None:
|
||||
model_path = f"./model/{self.model_name.lower()}_model.pth"
|
||||
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
|
||||
# 保存分类器和相关信息
|
||||
model_data = {
|
||||
'classifier_state_dict': self.classifier.state_dict(),
|
||||
'model_path': self.model_path,
|
||||
'input_size': 768,
|
||||
'device': str(self.device)
|
||||
}
|
||||
|
||||
torch.save(model_data, model_path)
|
||||
print(f"模型已保存到: {model_path}")
|
||||
|
||||
def load_model(self, model_path: str) -> None:
|
||||
"""加载模型"""
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
||||
|
||||
model_data = torch.load(model_path, map_location=self.device)
|
||||
|
||||
# 设置BERT模型路径
|
||||
self.model_path = model_data['model_path']
|
||||
|
||||
# 加载BERT
|
||||
self._load_bert()
|
||||
|
||||
# 重建分类器
|
||||
input_size = model_data['input_size']
|
||||
self.classifier = BertClassifier(input_size).to(self.device)
|
||||
|
||||
# 加载分类器权重
|
||||
self.classifier.load_state_dict(model_data['classifier_state_dict'])
|
||||
|
||||
self.is_trained = True
|
||||
print(f"已加载模型: {model_path}")
|
||||
|
||||
@staticmethod
|
||||
def load_data(train_path: str, test_path: str) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
|
||||
"""加载BERT格式的数据"""
|
||||
print("加载训练数据...")
|
||||
train_data = load_corpus_bert(train_path)
|
||||
print(f"训练数据量: {len(train_data)}")
|
||||
|
||||
print("加载测试数据...")
|
||||
test_data = load_corpus_bert(test_path)
|
||||
print(f"测试数据量: {len(test_data)}")
|
||||
|
||||
return train_data, test_data
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
parser = argparse.ArgumentParser(description='BERT情感分析模型训练')
|
||||
parser.add_argument('--train_path', type=str, default='./data/weibo2018/train.txt',
|
||||
help='训练数据路径')
|
||||
parser.add_argument('--test_path', type=str, default='./data/weibo2018/test.txt',
|
||||
help='测试数据路径')
|
||||
parser.add_argument('--model_path', type=str, default='./model/bert_model.pth',
|
||||
help='模型保存路径')
|
||||
parser.add_argument('--bert_path', type=str, default='./model/chinese_wwm_pytorch',
|
||||
help='BERT预训练模型路径')
|
||||
parser.add_argument('--epochs', type=int, default=10,
|
||||
help='训练轮数')
|
||||
parser.add_argument('--batch_size', type=int, default=100,
|
||||
help='批大小')
|
||||
parser.add_argument('--learning_rate', type=float, default=1e-3,
|
||||
help='学习率')
|
||||
parser.add_argument('--eval_only', action='store_true',
|
||||
help='仅评估已有模型,不进行训练')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 创建模型
|
||||
model = BertModel_Custom(args.bert_path)
|
||||
|
||||
if args.eval_only:
|
||||
# 仅评估模式
|
||||
print("评估模式:加载已有模型进行评估")
|
||||
model.load_model(args.model_path)
|
||||
|
||||
# 加载测试数据
|
||||
_, test_data = model.load_data(args.train_path, args.test_path)
|
||||
|
||||
# 评估模型
|
||||
model.evaluate(test_data)
|
||||
else:
|
||||
# 训练模式
|
||||
# 加载数据
|
||||
train_data, test_data = model.load_data(args.train_path, args.test_path)
|
||||
|
||||
# 训练模型
|
||||
model.train(
|
||||
train_data,
|
||||
num_epochs=args.epochs,
|
||||
batch_size=args.batch_size,
|
||||
learning_rate=args.learning_rate
|
||||
)
|
||||
|
||||
# 评估模型
|
||||
model.evaluate(test_data)
|
||||
|
||||
# 保存模型
|
||||
model.save_model(args.model_path)
|
||||
|
||||
# 示例预测
|
||||
print("\n示例预测:")
|
||||
test_texts = [
|
||||
"今天天气真好,心情很棒",
|
||||
"这部电影太无聊了,浪费时间",
|
||||
"哈哈哈,太有趣了"
|
||||
]
|
||||
|
||||
for text in test_texts:
|
||||
pred, conf = model.predict_single(text)
|
||||
sentiment = "正面" if pred == 1 else "负面"
|
||||
print(f"文本: {text}")
|
||||
print(f"预测: {sentiment} (置信度: {conf:.4f})")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,352 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
LSTM情感分析模型训练脚本
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
|
||||
from gensim import models
|
||||
from sklearn.metrics import accuracy_score, f1_score, classification_report, roc_auc_score
|
||||
from typing import List, Tuple, Dict, Any
|
||||
import numpy as np
|
||||
|
||||
from base_model import BaseModel
|
||||
|
||||
|
||||
class LSTMDataset(Dataset):
|
||||
"""LSTM数据集"""
|
||||
|
||||
def __init__(self, data: List[Tuple[str, int]], word2vec_model):
|
||||
self.data = []
|
||||
self.label = []
|
||||
|
||||
for text, label in data:
|
||||
vectors = []
|
||||
for word in text.split(" "):
|
||||
if word in word2vec_model.wv.key_to_index:
|
||||
vectors.append(word2vec_model.wv[word])
|
||||
|
||||
if len(vectors) > 0: # 确保有有效的词向量
|
||||
vectors = torch.Tensor(vectors)
|
||||
self.data.append(vectors)
|
||||
self.label.append(label)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.data[index], self.label[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.label)
|
||||
|
||||
|
||||
def collate_fn(data):
|
||||
"""批处理函数"""
|
||||
data.sort(key=lambda x: len(x[0]), reverse=True)
|
||||
data_length = [len(sq[0]) for sq in data]
|
||||
x = [i[0] for i in data]
|
||||
y = [i[1] for i in data]
|
||||
data = pad_sequence(x, batch_first=True, padding_value=0)
|
||||
return data, torch.tensor(y, dtype=torch.float32), data_length
|
||||
|
||||
|
||||
class LSTMNet(nn.Module):
|
||||
"""LSTM网络结构"""
|
||||
|
||||
def __init__(self, input_size, hidden_size, num_layers):
|
||||
super(LSTMNet, self).__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
|
||||
batch_first=True, bidirectional=True)
|
||||
self.fc = nn.Linear(hidden_size * 2, 1) # 双向LSTM
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x, lengths):
|
||||
device = x.device
|
||||
h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)
|
||||
c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)
|
||||
|
||||
packed_input = pack_padded_sequence(input=x, lengths=lengths, batch_first=True)
|
||||
packed_out, (h_n, h_c) = self.lstm(packed_input, (h0, c0))
|
||||
|
||||
# 双向LSTM,拼接最后的隐藏状态
|
||||
lstm_out = torch.cat([h_n[-2], h_n[-1]], 1)
|
||||
out = self.fc(lstm_out)
|
||||
out = self.sigmoid(out)
|
||||
return out
|
||||
|
||||
|
||||
class LSTMModel(BaseModel):
|
||||
"""LSTM情感分析模型"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("LSTM")
|
||||
self.word2vec_model = None
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
def _train_word2vec(self, train_data: List[Tuple[str, int]], **kwargs):
|
||||
"""训练Word2Vec词向量"""
|
||||
print("训练Word2Vec词向量...")
|
||||
|
||||
# 准备Word2Vec输入数据
|
||||
wv_input = [text.split(" ") for text, _ in train_data]
|
||||
|
||||
vector_size = kwargs.get('vector_size', 64)
|
||||
min_count = kwargs.get('min_count', 1)
|
||||
epochs = kwargs.get('epochs', 1000)
|
||||
|
||||
# 训练Word2Vec
|
||||
self.word2vec_model = models.Word2Vec(
|
||||
wv_input,
|
||||
vector_size=vector_size,
|
||||
min_count=min_count,
|
||||
epochs=epochs
|
||||
)
|
||||
|
||||
print(f"Word2Vec训练完成,词向量维度: {vector_size}")
|
||||
|
||||
def train(self, train_data: List[Tuple[str, int]], **kwargs) -> None:
|
||||
"""训练LSTM模型"""
|
||||
print(f"开始训练 {self.model_name} 模型...")
|
||||
|
||||
# 训练Word2Vec
|
||||
self._train_word2vec(train_data, **kwargs)
|
||||
|
||||
# 超参数
|
||||
learning_rate = kwargs.get('learning_rate', 5e-4)
|
||||
num_epochs = kwargs.get('num_epochs', 5)
|
||||
batch_size = kwargs.get('batch_size', 100)
|
||||
embed_size = kwargs.get('embed_size', 64)
|
||||
hidden_size = kwargs.get('hidden_size', 64)
|
||||
num_layers = kwargs.get('num_layers', 2)
|
||||
|
||||
print(f"LSTM超参数: lr={learning_rate}, epochs={num_epochs}, "
|
||||
f"batch_size={batch_size}, hidden_size={hidden_size}")
|
||||
|
||||
# 创建数据集
|
||||
train_dataset = LSTMDataset(train_data, self.word2vec_model)
|
||||
train_loader = DataLoader(train_dataset, batch_size=batch_size,
|
||||
collate_fn=collate_fn, shuffle=True)
|
||||
|
||||
# 创建模型
|
||||
self.model = LSTMNet(embed_size, hidden_size, num_layers).to(self.device)
|
||||
|
||||
# 损失函数和优化器
|
||||
criterion = nn.BCELoss()
|
||||
optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
|
||||
|
||||
# 训练循环
|
||||
self.model.train()
|
||||
for epoch in range(num_epochs):
|
||||
total_loss = 0
|
||||
num_batches = 0
|
||||
|
||||
for i, (x, labels, lengths) in enumerate(train_loader):
|
||||
x = x.to(self.device)
|
||||
labels = labels.to(self.device)
|
||||
|
||||
# 前向传播
|
||||
outputs = self.model(x, lengths)
|
||||
logits = outputs.view(-1)
|
||||
loss = criterion(logits, labels)
|
||||
|
||||
# 反向传播
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
num_batches += 1
|
||||
|
||||
if (i + 1) % 10 == 0:
|
||||
avg_loss = total_loss / num_batches
|
||||
print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}], Loss: {avg_loss:.4f}")
|
||||
|
||||
# 保存每个epoch的模型
|
||||
if kwargs.get('save_each_epoch', False):
|
||||
epoch_model_path = f"./model/lstm_epoch_{epoch+1}.pth"
|
||||
os.makedirs(os.path.dirname(epoch_model_path), exist_ok=True)
|
||||
torch.save(self.model.state_dict(), epoch_model_path)
|
||||
print(f"已保存模型: {epoch_model_path}")
|
||||
|
||||
self.is_trained = True
|
||||
print(f"{self.model_name} 模型训练完成!")
|
||||
|
||||
def predict(self, texts: List[str]) -> List[int]:
|
||||
"""预测文本情感"""
|
||||
if not self.is_trained:
|
||||
raise ValueError(f"模型 {self.model_name} 尚未训练,请先调用train方法")
|
||||
|
||||
# 创建数据集
|
||||
test_data = [(text, 0) for text in texts] # 标签无关紧要
|
||||
test_dataset = LSTMDataset(test_data, self.word2vec_model)
|
||||
test_loader = DataLoader(test_dataset, batch_size=32, collate_fn=collate_fn)
|
||||
|
||||
predictions = []
|
||||
self.model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
for x, _, lengths in test_loader:
|
||||
x = x.to(self.device)
|
||||
outputs = self.model(x, lengths)
|
||||
outputs = outputs.view(-1)
|
||||
|
||||
# 转换为类别标签
|
||||
preds = (outputs > 0.5).cpu().numpy()
|
||||
predictions.extend(preds.astype(int).tolist())
|
||||
|
||||
return predictions
|
||||
|
||||
def predict_single(self, text: str) -> Tuple[int, float]:
|
||||
"""预测单条文本的情感"""
|
||||
if not self.is_trained:
|
||||
raise ValueError(f"模型 {self.model_name} 尚未训练,请先调用train方法")
|
||||
|
||||
# 转换为词向量
|
||||
vectors = []
|
||||
for word in text.split(" "):
|
||||
if word in self.word2vec_model.wv.key_to_index:
|
||||
vectors.append(self.word2vec_model.wv[word])
|
||||
|
||||
if len(vectors) == 0:
|
||||
return 0, 0.5 # 如果没有有效词向量,返回默认值
|
||||
|
||||
# 转换为tensor
|
||||
x = torch.Tensor(vectors).unsqueeze(0).to(self.device) # 添加batch维度
|
||||
lengths = [len(vectors)]
|
||||
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
output = self.model(x, lengths)
|
||||
prob = output.item()
|
||||
prediction = int(prob > 0.5)
|
||||
confidence = prob if prediction == 1 else 1 - prob
|
||||
|
||||
return prediction, confidence
|
||||
|
||||
def save_model(self, model_path: str = None) -> None:
|
||||
"""保存模型"""
|
||||
if not self.is_trained:
|
||||
raise ValueError(f"模型 {self.model_name} 尚未训练,无法保存")
|
||||
|
||||
if model_path is None:
|
||||
model_path = f"./model/{self.model_name.lower()}_model.pth"
|
||||
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
|
||||
# 保存模型状态和Word2Vec
|
||||
model_data = {
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'word2vec_model': self.word2vec_model,
|
||||
'model_config': {
|
||||
'embed_size': 64,
|
||||
'hidden_size': 64,
|
||||
'num_layers': 2
|
||||
},
|
||||
'device': str(self.device)
|
||||
}
|
||||
|
||||
torch.save(model_data, model_path)
|
||||
print(f"模型已保存到: {model_path}")
|
||||
|
||||
def load_model(self, model_path: str) -> None:
|
||||
"""加载模型"""
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
||||
|
||||
model_data = torch.load(model_path, map_location=self.device)
|
||||
|
||||
# 加载Word2Vec
|
||||
self.word2vec_model = model_data['word2vec_model']
|
||||
|
||||
# 重建LSTM网络
|
||||
config = model_data['model_config']
|
||||
self.model = LSTMNet(
|
||||
config['embed_size'],
|
||||
config['hidden_size'],
|
||||
config['num_layers']
|
||||
).to(self.device)
|
||||
|
||||
# 加载模型权重
|
||||
self.model.load_state_dict(model_data['model_state_dict'])
|
||||
|
||||
self.is_trained = True
|
||||
print(f"已加载模型: {model_path}")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
parser = argparse.ArgumentParser(description='LSTM情感分析模型训练')
|
||||
parser.add_argument('--train_path', type=str, default='./data/weibo2018/train.txt',
|
||||
help='训练数据路径')
|
||||
parser.add_argument('--test_path', type=str, default='./data/weibo2018/test.txt',
|
||||
help='测试数据路径')
|
||||
parser.add_argument('--model_path', type=str, default='./model/lstm_model.pth',
|
||||
help='模型保存路径')
|
||||
parser.add_argument('--epochs', type=int, default=5,
|
||||
help='训练轮数')
|
||||
parser.add_argument('--batch_size', type=int, default=100,
|
||||
help='批大小')
|
||||
parser.add_argument('--hidden_size', type=int, default=64,
|
||||
help='LSTM隐藏层大小')
|
||||
parser.add_argument('--learning_rate', type=float, default=5e-4,
|
||||
help='学习率')
|
||||
parser.add_argument('--eval_only', action='store_true',
|
||||
help='仅评估已有模型,不进行训练')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 创建模型
|
||||
model = LSTMModel()
|
||||
|
||||
if args.eval_only:
|
||||
# 仅评估模式
|
||||
print("评估模式:加载已有模型进行评估")
|
||||
model.load_model(args.model_path)
|
||||
|
||||
# 加载测试数据
|
||||
_, test_data = BaseModel.load_data(args.train_path, args.test_path)
|
||||
|
||||
# 评估模型
|
||||
model.evaluate(test_data)
|
||||
else:
|
||||
# 训练模式
|
||||
# 加载数据
|
||||
train_data, test_data = BaseModel.load_data(args.train_path, args.test_path)
|
||||
|
||||
# 训练模型
|
||||
model.train(
|
||||
train_data,
|
||||
num_epochs=args.epochs,
|
||||
batch_size=args.batch_size,
|
||||
hidden_size=args.hidden_size,
|
||||
learning_rate=args.learning_rate
|
||||
)
|
||||
|
||||
# 评估模型
|
||||
model.evaluate(test_data)
|
||||
|
||||
# 保存模型
|
||||
model.save_model(args.model_path)
|
||||
|
||||
# 示例预测
|
||||
print("\n示例预测:")
|
||||
test_texts = [
|
||||
"今天天气真好,心情很棒",
|
||||
"这部电影太无聊了,浪费时间",
|
||||
"哈哈哈,太有趣了"
|
||||
]
|
||||
|
||||
for text in test_texts:
|
||||
pred, conf = model.predict_single(text)
|
||||
sentiment = "正面" if pred == 1 else "负面"
|
||||
print(f"文本: {text}")
|
||||
print(f"预测: {sentiment} (置信度: {conf:.4f})")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,310 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
统一的情感分析预测程序
|
||||
支持加载所有模型进行情感预测
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, Tuple, List
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
# 导入所有模型类
|
||||
from bayes_train import BayesModel
|
||||
from svm_train import SVMModel
|
||||
from xgboost_train import XGBoostModel
|
||||
from lstm_train import LSTMModel
|
||||
from bert_train import BertModel_Custom
|
||||
from utils import processing
|
||||
|
||||
|
||||
class SentimentPredictor:
|
||||
"""情感分析预测器"""
|
||||
|
||||
def __init__(self):
|
||||
self.models = {}
|
||||
self.available_models = {
|
||||
'bayes': BayesModel,
|
||||
'svm': SVMModel,
|
||||
'xgboost': XGBoostModel,
|
||||
'lstm': LSTMModel,
|
||||
'bert': BertModel_Custom
|
||||
}
|
||||
|
||||
def load_model(self, model_type: str, model_path: str, **kwargs) -> None:
|
||||
"""加载指定类型的模型
|
||||
|
||||
Args:
|
||||
model_type: 模型类型 ('bayes', 'svm', 'xgboost', 'lstm', 'bert')
|
||||
model_path: 模型文件路径
|
||||
**kwargs: 其他参数(如BERT的预训练模型路径)
|
||||
"""
|
||||
if model_type not in self.available_models:
|
||||
raise ValueError(f"不支持的模型类型: {model_type}")
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
print(f"警告: 模型文件不存在: {model_path}")
|
||||
return
|
||||
|
||||
print(f"加载 {model_type.upper()} 模型...")
|
||||
|
||||
try:
|
||||
if model_type == 'bert':
|
||||
# BERT需要额外的预训练模型路径
|
||||
bert_path = kwargs.get('bert_path', './model/chinese_wwm_pytorch')
|
||||
model = BertModel_Custom(bert_path)
|
||||
else:
|
||||
model = self.available_models[model_type]()
|
||||
|
||||
model.load_model(model_path)
|
||||
self.models[model_type] = model
|
||||
print(f"{model_type.upper()} 模型加载成功")
|
||||
|
||||
except Exception as e:
|
||||
print(f"加载 {model_type.upper()} 模型失败: {e}")
|
||||
|
||||
def load_all_models(self, model_dir: str = './model', bert_path: str = './model/chinese_wwm_pytorch') -> None:
|
||||
"""加载所有可用的模型
|
||||
|
||||
Args:
|
||||
model_dir: 模型文件目录
|
||||
bert_path: BERT预训练模型路径
|
||||
"""
|
||||
model_files = {
|
||||
'bayes': os.path.join(model_dir, 'bayes_model.pkl'),
|
||||
'svm': os.path.join(model_dir, 'svm_model.pkl'),
|
||||
'xgboost': os.path.join(model_dir, 'xgboost_model.pkl'),
|
||||
'lstm': os.path.join(model_dir, 'lstm_model.pth'),
|
||||
'bert': os.path.join(model_dir, 'bert_model.pth')
|
||||
}
|
||||
|
||||
print("开始加载所有可用模型...")
|
||||
for model_type, model_path in model_files.items():
|
||||
self.load_model(model_type, model_path, bert_path=bert_path)
|
||||
|
||||
print(f"\n已加载 {len(self.models)} 个模型: {list(self.models.keys())}")
|
||||
|
||||
def predict_single(self, text: str, model_type: str = None) -> Dict[str, Tuple[int, float]]:
|
||||
"""预测单条文本的情感
|
||||
|
||||
Args:
|
||||
text: 待预测文本
|
||||
model_type: 指定模型类型,如果为None则使用所有已加载的模型
|
||||
|
||||
Returns:
|
||||
Dict[model_type, (prediction, confidence)]
|
||||
"""
|
||||
# 文本预处理
|
||||
processed_text = processing(text)
|
||||
|
||||
if model_type:
|
||||
if model_type not in self.models:
|
||||
raise ValueError(f"模型 {model_type} 未加载")
|
||||
|
||||
prediction, confidence = self.models[model_type].predict_single(processed_text)
|
||||
return {model_type: (prediction, confidence)}
|
||||
|
||||
# 使用所有模型预测
|
||||
results = {}
|
||||
for name, model in self.models.items():
|
||||
try:
|
||||
prediction, confidence = model.predict_single(processed_text)
|
||||
results[name] = (prediction, confidence)
|
||||
except Exception as e:
|
||||
print(f"模型 {name} 预测失败: {e}")
|
||||
results[name] = (0, 0.0)
|
||||
|
||||
return results
|
||||
|
||||
def predict_batch(self, texts: List[str], model_type: str = None) -> Dict[str, List[int]]:
|
||||
"""批量预测文本情感
|
||||
|
||||
Args:
|
||||
texts: 待预测文本列表
|
||||
model_type: 指定模型类型,如果为None则使用所有已加载的模型
|
||||
|
||||
Returns:
|
||||
Dict[model_type, predictions]
|
||||
"""
|
||||
# 文本预处理
|
||||
processed_texts = [processing(text) for text in texts]
|
||||
|
||||
if model_type:
|
||||
if model_type not in self.models:
|
||||
raise ValueError(f"模型 {model_type} 未加载")
|
||||
|
||||
predictions = self.models[model_type].predict(processed_texts)
|
||||
return {model_type: predictions}
|
||||
|
||||
# 使用所有模型预测
|
||||
results = {}
|
||||
for name, model in self.models.items():
|
||||
try:
|
||||
predictions = model.predict(processed_texts)
|
||||
results[name] = predictions
|
||||
except Exception as e:
|
||||
print(f"模型 {name} 预测失败: {e}")
|
||||
results[name] = [0] * len(texts)
|
||||
|
||||
return results
|
||||
|
||||
def ensemble_predict(self, text: str, weights: Dict[str, float] = None) -> Tuple[int, float]:
|
||||
"""集成预测(多个模型投票)
|
||||
|
||||
Args:
|
||||
text: 待预测文本
|
||||
weights: 模型权重,如果为None则平均权重
|
||||
|
||||
Returns:
|
||||
(prediction, confidence)
|
||||
"""
|
||||
if len(self.models) == 0:
|
||||
raise ValueError("没有加载任何模型")
|
||||
|
||||
results = self.predict_single(text)
|
||||
|
||||
if weights is None:
|
||||
weights = {name: 1.0 for name in results.keys()}
|
||||
|
||||
# 加权平均
|
||||
total_weight = 0
|
||||
weighted_prob = 0
|
||||
|
||||
for model_name, (pred, conf) in results.items():
|
||||
if model_name in weights:
|
||||
weight = weights[model_name]
|
||||
prob = conf if pred == 1 else 1 - conf
|
||||
weighted_prob += prob * weight
|
||||
total_weight += weight
|
||||
|
||||
if total_weight == 0:
|
||||
return 0, 0.5
|
||||
|
||||
final_prob = weighted_prob / total_weight
|
||||
final_pred = int(final_prob > 0.5)
|
||||
final_conf = final_prob if final_pred == 1 else 1 - final_prob
|
||||
|
||||
return final_pred, final_conf
|
||||
|
||||
def interactive_predict(self):
|
||||
"""交互式预测模式"""
|
||||
if len(self.models) == 0:
|
||||
print("错误: 没有加载任何模型,请先加载模型")
|
||||
return
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("="*50)
|
||||
print(f"已加载模型: {', '.join(self.models.keys())}")
|
||||
print("输入 'q' 退出程序")
|
||||
print("输入 'models' 查看模型列表")
|
||||
print("输入 'ensemble' 使用集成预测")
|
||||
print("-"*50)
|
||||
|
||||
while True:
|
||||
try:
|
||||
text = input("\n请输入要分析的微博内容: ").strip()
|
||||
|
||||
if text.lower() == 'q':
|
||||
print("👋 再见!")
|
||||
break
|
||||
|
||||
if text.lower() == 'models':
|
||||
print(f"已加载模型: {list(self.models.keys())}")
|
||||
continue
|
||||
|
||||
if text.lower() == 'ensemble':
|
||||
if len(self.models) > 1:
|
||||
pred, conf = self.ensemble_predict(text)
|
||||
sentiment = "😊 正面" if pred == 1 else "😞 负面"
|
||||
print(f"\n🤖 集成预测结果:")
|
||||
print(f" 情感倾向: {sentiment}")
|
||||
print(f" 置信度: {conf:.4f}")
|
||||
else:
|
||||
print("❌ 集成预测需要至少2个模型")
|
||||
continue
|
||||
|
||||
if not text:
|
||||
print("❌ 请输入有效内容")
|
||||
continue
|
||||
|
||||
# 预测
|
||||
results = self.predict_single(text)
|
||||
|
||||
print(f"\n📝 原文: {text}")
|
||||
print("🔍 预测结果:")
|
||||
|
||||
for model_name, (pred, conf) in results.items():
|
||||
sentiment = "😊 正面" if pred == 1 else "😞 负面"
|
||||
print(f" {model_name.upper():8}: {sentiment} (置信度: {conf:.4f})")
|
||||
|
||||
# 如果有多个模型,显示集成结果
|
||||
if len(results) > 1:
|
||||
ensemble_pred, ensemble_conf = self.ensemble_predict(text)
|
||||
ensemble_sentiment = "😊 正面" if ensemble_pred == 1 else "😞 负面"
|
||||
print(f" {'集成':8}: {ensemble_sentiment} (置信度: {ensemble_conf:.4f})")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n👋 程序被中断,再见!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"❌ 预测过程中出现错误: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
parser = argparse.ArgumentParser(description='微博情感分析统一预测程序')
|
||||
parser.add_argument('--model_dir', type=str, default='./model',
|
||||
help='模型文件目录')
|
||||
parser.add_argument('--bert_path', type=str, default='./model/chinese_wwm_pytorch',
|
||||
help='BERT预训练模型路径')
|
||||
parser.add_argument('--model_type', type=str, choices=['bayes', 'svm', 'xgboost', 'lstm', 'bert'],
|
||||
help='指定单个模型类型进行预测')
|
||||
parser.add_argument('--text', type=str,
|
||||
help='直接预测指定文本')
|
||||
parser.add_argument('--interactive', action='store_true', default=True,
|
||||
help='交互式预测模式(默认)')
|
||||
parser.add_argument('--ensemble', action='store_true',
|
||||
help='使用集成预测')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 创建预测器
|
||||
predictor = SentimentPredictor()
|
||||
|
||||
# 加载模型
|
||||
if args.model_type:
|
||||
# 加载指定模型
|
||||
model_files = {
|
||||
'bayes': 'bayes_model.pkl',
|
||||
'svm': 'svm_model.pkl',
|
||||
'xgboost': 'xgboost_model.pkl',
|
||||
'lstm': 'lstm_model.pth',
|
||||
'bert': 'bert_model.pth'
|
||||
}
|
||||
model_path = os.path.join(args.model_dir, model_files[args.model_type])
|
||||
predictor.load_model(args.model_type, model_path, bert_path=args.bert_path)
|
||||
else:
|
||||
# 加载所有模型
|
||||
predictor.load_all_models(args.model_dir, args.bert_path)
|
||||
|
||||
# 如果指定了文本,直接预测
|
||||
if args.text:
|
||||
if args.ensemble and len(predictor.models) > 1:
|
||||
pred, conf = predictor.ensemble_predict(args.text)
|
||||
sentiment = "正面" if pred == 1 else "负面"
|
||||
print(f"文本: {args.text}")
|
||||
print(f"集成预测: {sentiment} (置信度: {conf:.4f})")
|
||||
else:
|
||||
results = predictor.predict_single(args.text, args.model_type)
|
||||
print(f"文本: {args.text}")
|
||||
for model_name, (pred, conf) in results.items():
|
||||
sentiment = "正面" if pred == 1 else "负面"
|
||||
print(f"{model_name.upper()}: {sentiment} (置信度: {conf:.4f})")
|
||||
elif args.interactive:
|
||||
# 交互式模式
|
||||
predictor.interactive_predict()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,6 +1,9 @@
|
||||
jieba==0.42.1
|
||||
pytorch==1.7.1
|
||||
sklearn==0.23.2
|
||||
xgboost==1.2.1
|
||||
transformers==4.6.1
|
||||
gensim==4.0.1
|
||||
torch>=1.7.1
|
||||
scikit-learn>=0.23.2
|
||||
xgboost>=1.2.1
|
||||
transformers>=4.6.1
|
||||
gensim>=4.0.1
|
||||
pandas>=1.3.0
|
||||
numpy>=1.20.0
|
||||
tqdm>=4.60.0
|
||||
@@ -0,0 +1,166 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
SVM情感分析模型训练脚本
|
||||
"""
|
||||
import argparse
|
||||
import pandas as pd
|
||||
from typing import List, Tuple
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn import svm
|
||||
from sklearn.metrics import accuracy_score, f1_score
|
||||
|
||||
from base_model import BaseModel
|
||||
from utils import stopwords
|
||||
|
||||
|
||||
class SVMModel(BaseModel):
|
||||
"""SVM情感分析模型"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("SVM")
|
||||
|
||||
def train(self, train_data: List[Tuple[str, int]], **kwargs) -> None:
|
||||
"""训练SVM模型
|
||||
|
||||
Args:
|
||||
train_data: 训练数据,格式为[(text, label), ...]
|
||||
**kwargs: 其他参数,支持kernel, C等SVM参数
|
||||
"""
|
||||
print(f"开始训练 {self.model_name} 模型...")
|
||||
|
||||
# 准备数据
|
||||
df_train = pd.DataFrame(train_data, columns=["words", "label"])
|
||||
|
||||
# 特征编码(TF-IDF模型)
|
||||
print("构建TF-IDF特征...")
|
||||
self.vectorizer = TfidfVectorizer(
|
||||
token_pattern=r'\[?\w+\]?',
|
||||
stop_words=stopwords
|
||||
)
|
||||
|
||||
X_train = self.vectorizer.fit_transform(df_train["words"])
|
||||
y_train = df_train["label"]
|
||||
|
||||
print(f"特征维度: {X_train.shape[1]}")
|
||||
|
||||
# 获取SVM参数
|
||||
kernel = kwargs.get('kernel', 'rbf')
|
||||
C = kwargs.get('C', 1.0)
|
||||
gamma = kwargs.get('gamma', 'scale')
|
||||
|
||||
# 训练模型
|
||||
print(f"训练SVM分类器 (kernel={kernel}, C={C}, gamma={gamma})...")
|
||||
self.model = svm.SVC(kernel=kernel, C=C, gamma=gamma, probability=True)
|
||||
self.model.fit(X_train, y_train)
|
||||
|
||||
self.is_trained = True
|
||||
print(f"{self.model_name} 模型训练完成!")
|
||||
|
||||
def predict(self, texts: List[str]) -> List[int]:
|
||||
"""预测文本情感
|
||||
|
||||
Args:
|
||||
texts: 待预测文本列表
|
||||
|
||||
Returns:
|
||||
预测结果列表
|
||||
"""
|
||||
if not self.is_trained:
|
||||
raise ValueError(f"模型 {self.model_name} 尚未训练,请先调用train方法")
|
||||
|
||||
# 特征转换
|
||||
X = self.vectorizer.transform(texts)
|
||||
|
||||
# 预测
|
||||
predictions = self.model.predict(X)
|
||||
|
||||
return predictions.tolist()
|
||||
|
||||
def predict_single(self, text: str) -> Tuple[int, float]:
|
||||
"""预测单条文本的情感
|
||||
|
||||
Args:
|
||||
text: 待预测文本
|
||||
|
||||
Returns:
|
||||
(predicted_label, confidence)
|
||||
"""
|
||||
if not self.is_trained:
|
||||
raise ValueError(f"模型 {self.model_name} 尚未训练,请先调用train方法")
|
||||
|
||||
# 特征转换
|
||||
X = self.vectorizer.transform([text])
|
||||
|
||||
# 预测
|
||||
prediction = self.model.predict(X)[0]
|
||||
probabilities = self.model.predict_proba(X)[0]
|
||||
confidence = max(probabilities)
|
||||
|
||||
return int(prediction), float(confidence)
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
parser = argparse.ArgumentParser(description='SVM情感分析模型训练')
|
||||
parser.add_argument('--train_path', type=str, default='./data/weibo2018/train.txt',
|
||||
help='训练数据路径')
|
||||
parser.add_argument('--test_path', type=str, default='./data/weibo2018/test.txt',
|
||||
help='测试数据路径')
|
||||
parser.add_argument('--model_path', type=str, default='./model/svm_model.pkl',
|
||||
help='模型保存路径')
|
||||
parser.add_argument('--kernel', type=str, default='rbf', choices=['linear', 'poly', 'rbf', 'sigmoid'],
|
||||
help='SVM核函数类型')
|
||||
parser.add_argument('--C', type=float, default=1.0,
|
||||
help='SVM正则化参数C')
|
||||
parser.add_argument('--gamma', type=str, default='scale',
|
||||
help='SVM核函数参数gamma')
|
||||
parser.add_argument('--eval_only', action='store_true',
|
||||
help='仅评估已有模型,不进行训练')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 创建模型
|
||||
model = SVMModel()
|
||||
|
||||
if args.eval_only:
|
||||
# 仅评估模式
|
||||
print("评估模式:加载已有模型进行评估")
|
||||
model.load_model(args.model_path)
|
||||
|
||||
# 加载测试数据
|
||||
_, test_data = BaseModel.load_data(args.train_path, args.test_path)
|
||||
|
||||
# 评估模型
|
||||
model.evaluate(test_data)
|
||||
else:
|
||||
# 训练模式
|
||||
# 加载数据
|
||||
train_data, test_data = BaseModel.load_data(args.train_path, args.test_path)
|
||||
|
||||
# 训练模型
|
||||
model.train(train_data, kernel=args.kernel, C=args.C, gamma=args.gamma)
|
||||
|
||||
# 评估模型
|
||||
model.evaluate(test_data)
|
||||
|
||||
# 保存模型
|
||||
model.save_model(args.model_path)
|
||||
|
||||
# 示例预测
|
||||
print("\n示例预测:")
|
||||
test_texts = [
|
||||
"今天天气真好,心情很棒",
|
||||
"这部电影太无聊了,浪费时间",
|
||||
"哈哈哈,太有趣了"
|
||||
]
|
||||
|
||||
for text in test_texts:
|
||||
pred, conf = model.predict_single(text)
|
||||
sentiment = "正面" if pred == 1 else "负面"
|
||||
print(f"文本: {text}")
|
||||
print(f"预测: {sentiment} (置信度: {conf:.4f})")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,12 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import jieba
|
||||
import re
|
||||
import os
|
||||
import pickle
|
||||
from typing import List, Tuple, Any
|
||||
|
||||
|
||||
# 加载停用词
|
||||
stopwords = []
|
||||
with open("data/stopwords.txt", "r", encoding="utf8") as f:
|
||||
for w in f:
|
||||
stopwords.append(w.strip())
|
||||
stopwords_path = "data/stopwords.txt"
|
||||
if os.path.exists(stopwords_path):
|
||||
with open(stopwords_path, "r", encoding="utf8") as f:
|
||||
for w in f:
|
||||
stopwords.append(w.strip())
|
||||
else:
|
||||
print(f"警告: 停用词文件 {stopwords_path} 不存在,将使用空停用词列表")
|
||||
|
||||
|
||||
def load_corpus(path):
|
||||
@@ -66,4 +74,65 @@ def processing_bert(text):
|
||||
text = re.sub("@.+?( |$)", " ", text) # 去除 @xxx (用户名)
|
||||
text = re.sub("【.+?】", " ", text) # 去除 【xx】 (里面的内容通常都不是用户自己写的)
|
||||
text = re.sub("\u200b", " ", text) # '\u200b'是这个数据集中的一个bad case, 不用特别在意
|
||||
return text
|
||||
return text
|
||||
|
||||
|
||||
def save_model(model: Any, model_path: str) -> None:
|
||||
"""
|
||||
保存模型到文件
|
||||
|
||||
Args:
|
||||
model: 要保存的模型对象
|
||||
model_path: 保存路径
|
||||
"""
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
|
||||
with open(model_path, 'wb') as f:
|
||||
pickle.dump(model, f)
|
||||
|
||||
print(f"模型已保存到: {model_path}")
|
||||
|
||||
|
||||
def load_model(model_path: str) -> Any:
|
||||
"""
|
||||
从文件加载模型
|
||||
|
||||
Args:
|
||||
model_path: 模型文件路径
|
||||
|
||||
Returns:
|
||||
加载的模型对象
|
||||
"""
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
||||
|
||||
with open(model_path, 'rb') as f:
|
||||
model = pickle.load(f)
|
||||
|
||||
print(f"已加载模型: {model_path}")
|
||||
return model
|
||||
|
||||
|
||||
def preprocess_text_simple(text: str) -> str:
|
||||
"""
|
||||
简单的文本预处理函数,用于预测时的文本清洗
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
|
||||
Returns:
|
||||
清洗后的文本
|
||||
"""
|
||||
# 数据清洗
|
||||
text = re.sub("\{%.+?%\}", " ", text) # 去除 {%xxx%}
|
||||
text = re.sub("@.+?( |$)", " ", text) # 去除 @xxx
|
||||
text = re.sub("【.+?】", " ", text) # 去除 【xx】
|
||||
text = re.sub("\u200b", " ", text) # 去除特殊字符
|
||||
|
||||
# 删除表情符号
|
||||
text = re.sub(r'[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\U00002600-\U000027BF\U0001f900-\U0001f9ff\U0001f018-\U0001f270\U0000231a-\U0000231b\U0000238d-\U0000238d\U000024c2-\U0001f251]+', '', text)
|
||||
|
||||
# 多个空格合并为一个
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
|
||||
return text.strip()
|
||||
@@ -0,0 +1,233 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
XGBoost情感分析模型训练脚本
|
||||
"""
|
||||
import argparse
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import List, Tuple
|
||||
from sklearn.feature_extraction.text import CountVectorizer
|
||||
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
|
||||
import xgboost as xgb
|
||||
|
||||
from base_model import BaseModel
|
||||
from utils import stopwords
|
||||
|
||||
|
||||
class XGBoostModel(BaseModel):
|
||||
"""XGBoost情感分析模型"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("XGBoost")
|
||||
|
||||
def train(self, train_data: List[Tuple[str, int]], **kwargs) -> None:
|
||||
"""训练XGBoost模型
|
||||
|
||||
Args:
|
||||
train_data: 训练数据,格式为[(text, label), ...]
|
||||
**kwargs: 其他参数,支持XGBoost的各种参数
|
||||
"""
|
||||
print(f"开始训练 {self.model_name} 模型...")
|
||||
|
||||
# 准备数据
|
||||
df_train = pd.DataFrame(train_data, columns=["words", "label"])
|
||||
|
||||
# 特征编码(词袋模型,限制特征数量)
|
||||
max_features = kwargs.get('max_features', 2000)
|
||||
print(f"构建词袋模型 (max_features={max_features})...")
|
||||
self.vectorizer = CountVectorizer(
|
||||
token_pattern=r'\[?\w+\]?',
|
||||
stop_words=stopwords,
|
||||
max_features=max_features
|
||||
)
|
||||
|
||||
X_train = self.vectorizer.fit_transform(df_train["words"])
|
||||
y_train = df_train["label"]
|
||||
|
||||
print(f"特征维度: {X_train.shape[1]}")
|
||||
|
||||
# XGBoost参数设置
|
||||
params = {
|
||||
'booster': kwargs.get('booster', 'gbtree'),
|
||||
'max_depth': kwargs.get('max_depth', 6),
|
||||
'scale_pos_weight': kwargs.get('scale_pos_weight', 0.5),
|
||||
'colsample_bytree': kwargs.get('colsample_bytree', 0.8),
|
||||
'objective': 'binary:logistic',
|
||||
'eval_metric': 'error',
|
||||
'eta': kwargs.get('eta', 0.3),
|
||||
'nthread': kwargs.get('nthread', 10),
|
||||
}
|
||||
|
||||
num_boost_round = kwargs.get('num_boost_round', 200)
|
||||
|
||||
print(f"训练XGBoost分类器...")
|
||||
print(f"参数: {params}")
|
||||
print(f"迭代轮数: {num_boost_round}")
|
||||
|
||||
# 创建DMatrix
|
||||
dmatrix = xgb.DMatrix(X_train, label=y_train)
|
||||
|
||||
# 训练模型
|
||||
self.model = xgb.train(params, dmatrix, num_boost_round=num_boost_round)
|
||||
|
||||
self.is_trained = True
|
||||
print(f"{self.model_name} 模型训练完成!")
|
||||
|
||||
def predict(self, texts: List[str]) -> List[int]:
|
||||
"""预测文本情感
|
||||
|
||||
Args:
|
||||
texts: 待预测文本列表
|
||||
|
||||
Returns:
|
||||
预测结果列表
|
||||
"""
|
||||
if not self.is_trained:
|
||||
raise ValueError(f"模型 {self.model_name} 尚未训练,请先调用train方法")
|
||||
|
||||
# 特征转换
|
||||
X = self.vectorizer.transform(texts)
|
||||
|
||||
# 创建DMatrix
|
||||
dmatrix = xgb.DMatrix(X)
|
||||
|
||||
# 预测概率
|
||||
y_prob = self.model.predict(dmatrix)
|
||||
|
||||
# 转换为类别标签
|
||||
y_pred = (y_prob > 0.5).astype(int)
|
||||
|
||||
return y_pred.tolist()
|
||||
|
||||
def predict_single(self, text: str) -> Tuple[int, float]:
|
||||
"""预测单条文本的情感
|
||||
|
||||
Args:
|
||||
text: 待预测文本
|
||||
|
||||
Returns:
|
||||
(predicted_label, confidence)
|
||||
"""
|
||||
if not self.is_trained:
|
||||
raise ValueError(f"模型 {self.model_name} 尚未训练,请先调用train方法")
|
||||
|
||||
# 特征转换
|
||||
X = self.vectorizer.transform([text])
|
||||
|
||||
# 创建DMatrix
|
||||
dmatrix = xgb.DMatrix(X)
|
||||
|
||||
# 预测概率
|
||||
prob = self.model.predict(dmatrix)[0]
|
||||
|
||||
# 转换为类别标签和置信度
|
||||
prediction = int(prob > 0.5)
|
||||
confidence = prob if prediction == 1 else 1 - prob
|
||||
|
||||
return prediction, float(confidence)
|
||||
|
||||
def evaluate(self, test_data: List[Tuple[str, int]]) -> dict:
|
||||
"""评估模型性能,包含AUC指标"""
|
||||
if not self.is_trained:
|
||||
raise ValueError(f"模型 {self.model_name} 尚未训练,请先调用train方法")
|
||||
|
||||
texts = [item[0] for item in test_data]
|
||||
labels = [item[1] for item in test_data]
|
||||
|
||||
# 预测类别
|
||||
predictions = self.predict(texts)
|
||||
|
||||
# 预测概率(用于计算AUC)
|
||||
X = self.vectorizer.transform(texts)
|
||||
dmatrix = xgb.DMatrix(X)
|
||||
probabilities = self.model.predict(dmatrix)
|
||||
|
||||
accuracy = accuracy_score(labels, predictions)
|
||||
f1 = f1_score(labels, predictions, average='weighted')
|
||||
auc = roc_auc_score(labels, probabilities)
|
||||
|
||||
print(f"\n{self.model_name} 模型评估结果:")
|
||||
print(f"准确率: {accuracy:.4f}")
|
||||
print(f"F1分数: {f1:.4f}")
|
||||
print(f"AUC: {auc:.4f}")
|
||||
|
||||
return {
|
||||
'accuracy': accuracy,
|
||||
'f1_score': f1,
|
||||
'auc': auc
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
parser = argparse.ArgumentParser(description='XGBoost情感分析模型训练')
|
||||
parser.add_argument('--train_path', type=str, default='./data/weibo2018/train.txt',
|
||||
help='训练数据路径')
|
||||
parser.add_argument('--test_path', type=str, default='./data/weibo2018/test.txt',
|
||||
help='测试数据路径')
|
||||
parser.add_argument('--model_path', type=str, default='./model/xgboost_model.pkl',
|
||||
help='模型保存路径')
|
||||
parser.add_argument('--max_features', type=int, default=2000,
|
||||
help='最大特征数量')
|
||||
parser.add_argument('--max_depth', type=int, default=6,
|
||||
help='XGBoost最大深度')
|
||||
parser.add_argument('--eta', type=float, default=0.3,
|
||||
help='XGBoost学习率')
|
||||
parser.add_argument('--num_boost_round', type=int, default=200,
|
||||
help='XGBoost迭代轮数')
|
||||
parser.add_argument('--eval_only', action='store_true',
|
||||
help='仅评估已有模型,不进行训练')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 创建模型
|
||||
model = XGBoostModel()
|
||||
|
||||
if args.eval_only:
|
||||
# 仅评估模式
|
||||
print("评估模式:加载已有模型进行评估")
|
||||
model.load_model(args.model_path)
|
||||
|
||||
# 加载测试数据
|
||||
_, test_data = BaseModel.load_data(args.train_path, args.test_path)
|
||||
|
||||
# 评估模型
|
||||
model.evaluate(test_data)
|
||||
else:
|
||||
# 训练模式
|
||||
# 加载数据
|
||||
train_data, test_data = BaseModel.load_data(args.train_path, args.test_path)
|
||||
|
||||
# 训练模型
|
||||
model.train(
|
||||
train_data,
|
||||
max_features=args.max_features,
|
||||
max_depth=args.max_depth,
|
||||
eta=args.eta,
|
||||
num_boost_round=args.num_boost_round
|
||||
)
|
||||
|
||||
# 评估模型
|
||||
model.evaluate(test_data)
|
||||
|
||||
# 保存模型
|
||||
model.save_model(args.model_path)
|
||||
|
||||
# 示例预测
|
||||
print("\n示例预测:")
|
||||
test_texts = [
|
||||
"今天天气真好,心情很棒",
|
||||
"这部电影太无聊了,浪费时间",
|
||||
"哈哈哈,太有趣了"
|
||||
]
|
||||
|
||||
for text in test_texts:
|
||||
pred, conf = model.predict_single(text)
|
||||
sentiment = "正面" if pred == 1 else "负面"
|
||||
print(f"文本: {text}")
|
||||
print(f"预测: {sentiment} (置信度: {conf:.4f})")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user