{ "cells": [ { "cell_type": "markdown", "metadata": { "cell_id": 39 }, "source": [ "### 加载数据集" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cell_id": 1 }, "outputs": [], "source": [ "from utils import load_corpus_bert\n", "\n", "TRAIN_PATH = \"./data/weibo2018/train.txt\"\n", "TEST_PATH = \"./data/weibo2018/test.txt\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "cell_id": 3 }, "outputs": [], "source": [ "# 分别加载训练集和测试集\n", "train_data = load_corpus_bert(TRAIN_PATH)\n", "test_data = load_corpus_bert(TEST_PATH)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "cell_id": 4 }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
textlabel
0“书中自有黄金屋,书中自有颜如玉”。沿着岁月的长河跋涉,或是风光旖旎,或是姹紫嫣红,万千...1
1这是英超被黑的最惨的一次[二哈][二哈]十几年来,中国只有孙继海,董方卓,郑智,李铁登陆过英...0
2中国远洋海运集团副总经理俞曾港4月21日在 上表示,中央企业“走出去”是要站在更高的平台参...1
3看《流星花园》其实也还好啦,现在的观念以及时尚眼光都不一样了,或许十几年之后的人看我们的现在...1
4汉武帝的罪己诏的真实性尽管存在着争议,然而“轮台罪己诏”作为中国历史上第一份皇帝自我批评的文...1
\n", "
" ], "text/plain": [ " text label\n", "0 “书中自有黄金屋,书中自有颜如玉”。沿着岁月的长河跋涉,或是风光旖旎,或是姹紫嫣红,万千... 1\n", "1 这是英超被黑的最惨的一次[二哈][二哈]十几年来,中国只有孙继海,董方卓,郑智,李铁登陆过英... 0\n", "2 中国远洋海运集团副总经理俞曾港4月21日在 上表示,中央企业“走出去”是要站在更高的平台参... 1\n", "3 看《流星花园》其实也还好啦,现在的观念以及时尚眼光都不一样了,或许十几年之后的人看我们的现在... 1\n", "4 汉武帝的罪己诏的真实性尽管存在着争议,然而“轮台罪己诏”作为中国历史上第一份皇帝自我批评的文... 1" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "df_train = pd.DataFrame(train_data, columns=[\"text\", \"label\"])\n", "df_test = pd.DataFrame(test_data, columns=[\"text\", \"label\"])\n", "df_train.head()" ] }, { "cell_type": "markdown", "metadata": { "cell_id": 41 }, "source": [ "### 加载Bert" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "cell_id": 5 }, "outputs": [], "source": [ "import os\n", "from transformers import BertTokenizer, BertModel\n", "\n", "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\" # 在我的电脑上不加这一句, bert模型会报错\n", "MODEL_PATH = \"./model/chinese_wwm_pytorch\" # 下载地址见 https://github.com/ymcui/Chinese-BERT-wwm" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "cell_id": 6 }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at ./model/chinese_wwm_pytorch were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] } ], "source": [ "# 加载\n", "tokenizer = BertTokenizer.from_pretrained(MODEL_PATH) # 分词器\n", "bert = BertModel.from_pretrained(MODEL_PATH) # 模型" ] }, { "cell_type": "markdown", "metadata": { "cell_id": 43 }, "source": [ "### 神经网络" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "cell_id": 7 }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "cell_id": 8 }, "outputs": [], "source": [ "# 超参数\n", "learning_rate = 1e-3\n", "input_size = 768\n", "num_epoches = 10\n", "batch_size = 100\n", "decay_rate = 0.9" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "cell_id": 9 }, "outputs": [], "source": [ "# 数据集\n", "class MyDataset(Dataset):\n", " def __init__(self, df):\n", " self.data = df[\"text\"].tolist()\n", " self.label = df[\"label\"].tolist()\n", "\n", " def __getitem__(self, index):\n", " data = self.data[index]\n", " label = self.label[index]\n", " return data, label\n", "\n", " def __len__(self):\n", " return len(self.label)\n", "\n", "# 训练集\n", "train_data = MyDataset(df_train)\n", "train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n", "\n", "# 测试集\n", "test_data = MyDataset(df_test)\n", "test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "cell_id": 10 }, "outputs": [], "source": [ "# 网络结构\n", "class Net(nn.Module):\n", " def __init__(self, input_size):\n", " super(Net, self).__init__()\n", " self.fc = nn.Linear(input_size, 1)\n", " self.sigmoid = nn.Sigmoid()\n", "\n", " def forward(self, x):\n", " out = self.fc(x)\n", " out = self.sigmoid(out)\n", " return out\n", "\n", "net = Net(input_size).to(device)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "cell_id": 34 }, "outputs": [], "source": [ "from sklearn import metrics\n", "\n", "# 测试集效果检验\n", "def test():\n", " y_pred, y_true = [], []\n", "\n", " with torch.no_grad():\n", " for words, labels in test_loader:\n", " tokens = tokenizer(words, padding=True)\n", " input_ids = torch.tensor(tokens[\"input_ids\"]).to(device)\n", " attention_mask = torch.tensor(tokens[\"attention_mask\"]).to(device)\n", " last_hidden_states = bert(input_ids, attention_mask=attention_mask)\n", " bert_output = last_hidden_states[0][:, 0]\n", " outputs = net(bert_output) # 前向传播\n", " outputs = outputs.view(-1) # 将输出展平\n", " y_pred.append(outputs)\n", " y_true.append(labels)\n", "\n", " y_prob = torch.cat(y_pred)\n", " y_true = torch.cat(y_true)\n", " y_pred = y_prob.clone()\n", " y_pred[y_pred > 0.5] = 1\n", " y_pred[y_pred <= 0.5] = 0\n", " \n", " print(metrics.classification_report(y_true, y_pred))\n", " print(\"准确率:\", metrics.accuracy_score(y_true, y_pred))\n", " print(\"AUC:\", metrics.roc_auc_score(y_true, y_prob) )" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "cell_id": 11 }, "outputs": [], "source": [ "# 定义损失函数和优化器\n", "criterion = nn.BCELoss()\n", "optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)\n", "scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_rate)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "cell_id": 14, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch:1, step:10, loss:0.6710587739944458\n", "epoch:1, step:20, loss:0.6176288723945618\n", "epoch:1, step:30, loss:0.578593909740448\n", "epoch:1, step:40, loss:0.5502474308013916\n", "epoch:1, step:50, loss:0.5323082804679871\n", "epoch:1, step:60, loss:0.515110194683075\n", "epoch:1, step:70, loss:0.5127577185630798\n", "epoch:1, step:80, loss:0.48992329835891724\n", "epoch:1, step:90, loss:0.4868148863315582\n", "epoch:1, step:100, loss:0.49194520711898804\n", " precision recall f1-score support\n", "\n", " 0 0.74 0.74 0.74 155\n", " 1 0.88 0.88 0.88 345\n", "\n", " accuracy 0.84 500\n", " macro avg 0.81 0.81 0.81 500\n", "weighted avg 0.84 0.84 0.84 500\n", "\n", "准确率: 0.84\n", "AUC: 0.9027582982702197\n", "saved model: ./model/bert_dnn_1.model\n", "epoch:2, step:10, loss:0.46188774704933167\n", "epoch:2, step:20, loss:0.4335215985774994\n", "epoch:2, step:30, loss:0.4540901184082031\n", "epoch:2, step:40, loss:0.4392821788787842\n", "epoch:2, step:50, loss:0.47116056084632874\n", "epoch:2, step:60, loss:0.4669877886772156\n", "epoch:2, step:70, loss:0.4401330053806305\n", "epoch:2, step:80, loss:0.4518135190010071\n", "epoch:2, step:90, loss:0.4567466676235199\n", "epoch:2, step:100, loss:0.4663034975528717\n", " precision recall f1-score support\n", "\n", " 0 0.72 0.83 0.77 155\n", " 1 0.92 0.85 0.88 345\n", "\n", " accuracy 0.85 500\n", " macro avg 0.82 0.84 0.83 500\n", "weighted avg 0.86 0.85 0.85 500\n", "\n", "准确率: 0.846\n", "AUC: 0.9149322113136981\n", "saved model: ./model/bert_dnn_2.model\n", "epoch:3, step:10, loss:0.42892661690711975\n", "epoch:3, step:20, loss:0.4225884974002838\n", "epoch:3, step:30, loss:0.415252685546875\n", "epoch:3, step:40, loss:0.43130287528038025\n", "epoch:3, step:50, loss:0.42938193678855896\n", "epoch:3, step:60, loss:0.4340507388114929\n", "epoch:3, step:70, loss:0.4466826319694519\n", "epoch:3, step:80, loss:0.45244288444519043\n", "epoch:3, step:90, loss:0.41808539628982544\n", "epoch:3, step:100, loss:0.44330015778541565\n", " precision recall f1-score support\n", "\n", " 0 0.73 0.83 0.77 155\n", " 1 0.92 0.86 0.89 345\n", "\n", " accuracy 0.85 500\n", " macro avg 0.82 0.84 0.83 500\n", "weighted avg 0.86 0.85 0.85 500\n", "\n", "准确率: 0.85\n", "AUC: 0.9206545114539505\n", "saved model: ./model/bert_dnn_3.model\n", "epoch:4, step:10, loss:0.39769938588142395\n", "epoch:4, step:20, loss:0.4465697407722473\n", "epoch:4, step:30, loss:0.4216257929801941\n", "epoch:4, step:40, loss:0.41328248381614685\n", "epoch:4, step:50, loss:0.41364049911499023\n", "epoch:4, step:60, loss:0.4332212507724762\n", "epoch:4, step:70, loss:0.4280005395412445\n", "epoch:4, step:80, loss:0.41606149077415466\n", "epoch:4, step:90, loss:0.43310579657554626\n", "epoch:4, step:100, loss:0.4076871871948242\n", " precision recall f1-score support\n", "\n", " 0 0.76 0.80 0.78 155\n", " 1 0.91 0.88 0.90 345\n", "\n", " accuracy 0.86 500\n", " macro avg 0.83 0.84 0.84 500\n", "weighted avg 0.86 0.86 0.86 500\n", "\n", "准确率: 0.858\n", "AUC: 0.9222814399251986\n", "saved model: ./model/bert_dnn_4.model\n", "epoch:5, step:10, loss:0.39923620223999023\n", "epoch:5, step:20, loss:0.4110904633998871\n", "epoch:5, step:30, loss:0.4446052610874176\n", "epoch:5, step:40, loss:0.4050986170768738\n", "epoch:5, step:50, loss:0.41362982988357544\n", "epoch:5, step:60, loss:0.3961515724658966\n", "epoch:5, step:80, loss:0.43208274245262146\n", "epoch:5, step:90, loss:0.4123595356941223\n", "epoch:5, step:100, loss:0.4114747643470764\n", " precision recall f1-score support\n", "\n", " 0 0.75 0.81 0.78 155\n", " 1 0.91 0.88 0.90 345\n", "\n", " accuracy 0.86 500\n", " macro avg 0.83 0.85 0.84 500\n", "weighted avg 0.86 0.86 0.86 500\n", "\n", "准确率: 0.86\n", "AUC: 0.9251238896680692\n", "saved model: ./model/bert_dnn_5.model\n", "epoch:6, step:10, loss:0.4047953188419342\n", "epoch:6, step:20, loss:0.41434162855148315\n", "epoch:6, step:30, loss:0.4052816927433014\n", "epoch:6, step:40, loss:0.3726503849029541\n", "epoch:6, step:50, loss:0.4252064824104309\n", "epoch:6, step:60, loss:0.411870539188385\n", "epoch:6, step:70, loss:0.43613123893737793\n", "epoch:6, step:80, loss:0.4038943350315094\n", "epoch:6, step:90, loss:0.40738430619239807\n", "epoch:6, step:100, loss:0.41697797179222107\n", " precision recall f1-score support\n", "\n", " 0 0.76 0.82 0.79 155\n", " 1 0.92 0.88 0.90 345\n", "\n", " accuracy 0.86 500\n", " macro avg 0.84 0.85 0.84 500\n", "weighted avg 0.87 0.86 0.87 500\n", "\n", "准确率: 0.864\n", "AUC: 0.9266947171575501\n", "saved model: ./model/bert_dnn_6.model\n", "epoch:7, step:10, loss:0.4255238175392151\n", "epoch:7, step:20, loss:0.3951468765735626\n", "epoch:7, step:30, loss:0.41892367601394653\n", "epoch:7, step:40, loss:0.40587490797042847\n", "epoch:7, step:50, loss:0.3918803036212921\n", "epoch:7, step:60, loss:0.43665409088134766\n", "epoch:7, step:70, loss:0.4085603654384613\n", "epoch:7, step:80, loss:0.3877314627170563\n", "epoch:7, step:90, loss:0.3680875301361084\n", "epoch:7, step:100, loss:0.4211949408054352\n", " precision recall f1-score support\n", "\n", " 0 0.74 0.85 0.79 155\n", " 1 0.93 0.87 0.90 345\n", "\n", " accuracy 0.86 500\n", " macro avg 0.83 0.86 0.84 500\n", "weighted avg 0.87 0.86 0.86 500\n", "\n", "准确率: 0.86\n", "AUC: 0.9282094436652641\n", "saved model: ./model/bert_dnn_7.model\n", "epoch:8, step:10, loss:0.3657851815223694\n", "epoch:8, step:20, loss:0.3944622576236725\n", "epoch:8, step:30, loss:0.40657711029052734\n", "epoch:8, step:40, loss:0.3935934901237488\n", "epoch:8, step:50, loss:0.4171984791755676\n", "epoch:8, step:60, loss:0.4169773459434509\n", "epoch:8, step:70, loss:0.4021885395050049\n", "epoch:8, step:80, loss:0.4106557369232178\n", "epoch:8, step:100, loss:0.4116268754005432\n", " precision recall f1-score support\n", "\n", " 0 0.80 0.78 0.79 155\n", " 1 0.90 0.91 0.91 345\n", "\n", " accuracy 0.87 500\n", " macro avg 0.85 0.85 0.85 500\n", "weighted avg 0.87 0.87 0.87 500\n", "\n", "准确率: 0.87\n", "AUC: 0.9288078541374474\n", "saved model: ./model/bert_dnn_8.model\n", "epoch:9, step:10, loss:0.4415532052516937\n", "epoch:9, step:20, loss:0.4093624949455261\n", "epoch:9, step:30, loss:0.3825526833534241\n", "epoch:9, step:40, loss:0.3692132532596588\n", "epoch:9, step:50, loss:0.39409342408180237\n", "epoch:9, step:60, loss:0.40440621972084045\n", "epoch:9, step:70, loss:0.3859332203865051\n", "epoch:9, step:80, loss:0.40987101197242737\n", "epoch:9, step:90, loss:0.4061252176761627\n", "epoch:9, step:100, loss:0.4131951332092285\n", " precision recall f1-score support\n", "\n", " 0 0.75 0.84 0.79 155\n", " 1 0.92 0.88 0.90 345\n", "\n", " accuracy 0.86 500\n", " macro avg 0.84 0.86 0.85 500\n", "weighted avg 0.87 0.86 0.87 500\n", "\n", "准确率: 0.864\n", "AUC: 0.9293501636278635\n", "saved model: ./model/bert_dnn_9.model\n", "epoch:10, step:10, loss:0.40611615777015686\n", "epoch:10, step:20, loss:0.42403316497802734\n", "epoch:10, step:30, loss:0.3972412943840027\n", "epoch:10, step:40, loss:0.4144269526004791\n", "epoch:10, step:50, loss:0.37967294454574585\n", "epoch:10, step:60, loss:0.3992181420326233\n", "epoch:10, step:70, loss:0.3896545469760895\n", "epoch:10, step:80, loss:0.39779797196388245\n", "epoch:10, step:90, loss:0.38316115736961365\n", "epoch:10, step:100, loss:0.4042983055114746\n", " precision recall f1-score support\n", "\n", " 0 0.76 0.82 0.79 155\n", " 1 0.92 0.88 0.90 345\n", "\n", " accuracy 0.86 500\n", " macro avg 0.84 0.85 0.84 500\n", "weighted avg 0.87 0.86 0.86 500\n", "\n", "准确率: 0.862\n", "AUC: 0.9303973819541842\n", "saved model: ./model/bert_dnn_10.model\n" ] } ], "source": [ "# 迭代训练\n", "for epoch in range(num_epoches):\n", " total_loss = 0\n", " for i, (words, labels) in enumerate(train_loader):\n", " tokens = tokenizer(words, padding=True)\n", " input_ids = torch.tensor(tokens[\"input_ids\"]).to(device)\n", " attention_mask = torch.tensor(tokens[\"attention_mask\"]).to(device)\n", " labels = labels.float().to(device)\n", " with torch.no_grad():\n", " last_hidden_states = bert(input_ids, attention_mask=attention_mask)\n", " bert_output = last_hidden_states[0][:, 0]\n", " optimizer.zero_grad() # 梯度清零\n", " outputs = net(bert_output) # 前向传播\n", " logits = outputs.view(-1) # 将输出展平\n", " loss = criterion(logits, labels) # loss计算\n", " total_loss += loss\n", " loss.backward() # 反向传播,计算梯度\n", " optimizer.step() # 梯度更新\n", " if (i+1) % 10 == 0:\n", " print(\"epoch:{}, step:{}, loss:{}\".format(epoch+1, i+1, total_loss/10))\n", " total_loss = 0\n", " \n", " # learning_rate decay\n", " scheduler.step()\n", " \n", " # test\n", " test()\n", " \n", " # save model\n", " model_path = \"./model/bert_dnn_{}.model\".format(epoch+1)\n", " torch.save(net, model_path)\n", " print(\"saved model: \", model_path)" ] }, { "cell_type": "markdown", "metadata": { "cell_id": 23 }, "source": [ "### 手动输入句子,判断情感倾向(1正/0负)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "cell_id": 38 }, "outputs": [], "source": [ "net = torch.load(\"./model/bert_dnn_8.model\") # 训练过程中的巅峰时刻" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "cell_id": 37 }, "outputs": [ { "data": { "text/plain": [ "tensor([[0.9007],\n", " [0.2211]], grad_fn=)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "s = [\"华丽繁荣的城市、充满回忆的小镇、郁郁葱葱的山谷...\", \"突然就觉得人间不值得\"]\n", "tokens = tokenizer(s, padding=True)\n", "input_ids = torch.tensor(tokens[\"input_ids\"])\n", "attention_mask = torch.tensor(tokens[\"attention_mask\"])\n", "last_hidden_states = bert(input_ids, attention_mask=attention_mask)\n", "bert_output = last_hidden_states[0][:, 0]\n", "outputs = net(bert_output)\n", "outputs" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "cell_id": 27, "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "tensor([[0.9735],\n", " [0.9882]], grad_fn=)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "s = [\"今天天气真好\", \"今天天气特别特别棒\"]\n", "tokens = tokenizer(s, padding=True)\n", "input_ids = torch.tensor(tokens[\"input_ids\"])\n", "attention_mask = torch.tensor(tokens[\"attention_mask\"])\n", "last_hidden_states = bert(input_ids, attention_mask=attention_mask)\n", "bert_output = last_hidden_states[0][:, 0]\n", "outputs = net(bert_output)\n", "outputs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cell_id": 32 }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" }, "max_cell_id": 45 }, "nbformat": 4, "nbformat_minor": 5 }