Add the base class, configuration file, and training script for the Qwen3 sentiment analysis model, with support for various model sizes (0.6B, 4B, 8B).
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
*.js linguist-language=python
|
||||
|
||||
*.ipynb binary
|
||||
|
||||
* text=auto
|
||||
@@ -184,6 +184,7 @@ WeiboSentiment_LLM/models/
|
||||
WeiboSentiment_Finetuned/BertChinese-Lora/model/
|
||||
WeiboMultilingualSentiment/model/
|
||||
WeiboSentiment_MachineLearning/model/chinese_wwm_pytorch/
|
||||
WeiboSentiment_SmallQwen/models/
|
||||
|
||||
# LoRA 和 Adapter 权重
|
||||
*/adapter_model.safetensors
|
||||
|
||||
@@ -0,0 +1,171 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Qwen3模型基础类,统一接口
|
||||
"""
|
||||
import os
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple, Dict, Any
|
||||
import pandas as pd
|
||||
from sklearn.metrics import accuracy_score, f1_score, classification_report
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
|
||||
class BaseQwenModel(ABC):
    """Common interface for the Qwen3 sentiment-analysis models.

    Subclasses implement training, batch prediction and (de)serialisation.
    This base class supplies single-text prediction, evaluation and a data
    loader shared by all model variants.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.model = None
        self.is_trained = False

    @abstractmethod
    def train(self, train_data: List[Tuple[str, int]], **kwargs) -> None:
        """Train the model on ``(text, label)`` pairs."""

    @abstractmethod
    def predict(self, texts: List[str]) -> List[int]:
        """Return one predicted label per input text."""

    def predict_single(self, text: str) -> Tuple[int, float]:
        """Predict the sentiment of a single text.

        Args:
            text: The text to classify.

        Returns:
            ``(predicted_label, confidence)``. This default implementation
            has no probability available and always reports ``0.0`` as the
            confidence.
        """
        predictions = self.predict([text])
        return predictions[0], 0.0  # default confidence is 0

    def evaluate(self, test_data: List[Tuple[str, int]]) -> Dict[str, Any]:
        """Evaluate the model on labelled data and print a report.

        Args:
            test_data: ``(text, label)`` pairs.

        Returns:
            A dict with ``accuracy``, ``f1_score`` (weighted) and the textual
            ``classification_report``.

        Raises:
            ValueError: If the model has not been trained yet.
        """
        if not self.is_trained:
            raise ValueError(f"模型 {self.model_name} 尚未训练,请先调用train方法")

        texts = [item[0] for item in test_data]
        labels = [item[1] for item in test_data]

        predictions = self.predict(texts)

        accuracy = accuracy_score(labels, predictions)
        f1 = f1_score(labels, predictions, average='weighted')
        # Build the report once and reuse it for both printing and returning.
        report = classification_report(labels, predictions)

        print(f"\n{self.model_name} 模型评估结果:")
        print(f"准确率: {accuracy:.4f}")
        print(f"F1分数: {f1:.4f}")
        print("\n详细报告:")
        print(report)

        return {
            'accuracy': accuracy,
            'f1_score': f1,
            'classification_report': report,
        }

    @abstractmethod
    def save_model(self, model_path: str = None) -> None:
        """Persist the model to ``model_path``."""

    @abstractmethod
    def load_model(self, model_path: str) -> None:
        """Restore the model from ``model_path``."""

    @staticmethod
    def _load_corpus(path: str) -> List[Tuple[str, int]]:
        """Read a tab-separated ``text<TAB>label`` file into a list of pairs.

        Lines with fewer than two tab-separated fields are skipped.
        """
        data = []
        with open(path, "r", encoding="utf8") as f:
            for line in f:
                parts = line.strip().split("\t")
                if len(parts) >= 2:
                    data.append((parts[0], int(parts[1])))
        return data

    @staticmethod
    def _split_csv_frame(df) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
        """Turn a frame with ``review``/``label`` columns into a stratified split."""
        # Column-wise extraction avoids the per-row overhead of iterrows().
        data = list(zip(df['review'].tolist(), df['label'].tolist()))

        # Hold out a fixed 5000 samples; with 5000 rows or fewer, hold out 20%.
        test_size = 5000 if len(data) > 5000 else 0.2
        train_data, test_data = train_test_split(
            data,
            test_size=test_size,
            random_state=42,
            stratify=[label for _, label in data],
        )

        print(f"训练数据量: {len(train_data)}")
        print(f"测试数据量: {len(test_data)}")
        return train_data, test_data

    @staticmethod
    def load_data(train_path: str = None, test_path: str = None, csv_path: str = 'dataset/weibo_senti_100k.csv') -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
        """Load training and test data.

        Sources are tried in order: the CSV file, then the pair of txt files,
        then a small built-in sample set. A CSV that exists but lacks the
        ``review``/``label`` columns no longer ends the search (which used to
        make this function return ``None``); the next source is tried.

        Args:
            train_path: Optional path to the training txt file.
            test_path: Optional path to the test txt file.
            csv_path: Path to the CSV file, tried first.

        Returns:
            ``(train_data, test_data)``, each a list of ``(text, label)``.
        """
        # Prefer the CSV file when it is present and well-formed.
        if os.path.exists(csv_path):
            print(f"从CSV文件加载数据: {csv_path}")
            df = pd.read_csv(csv_path)

            if 'review' in df.columns and 'label' in df.columns:
                return BaseQwenModel._split_csv_frame(df)
            print(f"CSV文件格式不正确,缺少'review'或'label'列")

        # Otherwise fall back to the txt files when both exist.
        if train_path and test_path and os.path.exists(train_path) and os.path.exists(test_path):
            print("从txt文件加载训练数据...")
            train_data = BaseQwenModel._load_corpus(train_path)
            print(f"训练数据量: {len(train_data)}")

            print("从txt文件加载测试数据...")
            test_data = BaseQwenModel._load_corpus(test_path)
            print(f"测试数据量: {len(test_data)}")

            return train_data, test_data

        # Nothing usable: explain the expected formats and use demo samples.
        print("未找到数据文件!")
        print("请确保以下文件之一存在:")
        print(f"1. CSV文件: {csv_path}")
        print(f"2. txt文件: {train_path} 和 {test_path}")
        print("\n数据格式要求:")
        print("CSV文件: 包含'review'和'label'列")
        print("txt文件: 每行格式为'文本内容\\t标签'")

        sample_data = [
            ("今天天气真好,心情很棒!", 1),
            ("这部电影太无聊了", 0),
            ("非常喜欢这个产品", 1),
            ("服务态度很差", 0),
            ("质量不错,值得推荐", 1)
        ]

        print("使用样例数据进行演示...")
        train_data = sample_data * 20  # replicate the samples for the demo
        test_data = sample_data * 5

        return train_data, test_data
|
||||
@@ -0,0 +1,280 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# weibo_senti_100k 说明\n",
|
||||
"0. **下载地址:** [百度网盘](https://pan.baidu.com/s/1DoQbki3YwqkuwQUOj64R_g)\n",
|
||||
"1. **数据概览:** 10 万多条,带情感标注 新浪微博,正负向评论约各 5 万条\n",
|
||||
"2. **推荐实验:** 情感/观点/评论 倾向性分析\n",
|
||||
"2. **数据来源:** [新浪微博](https://weibo.com/)\n",
|
||||
"3. **原数据集:** [新浪微博,情感分析标记语料共12万条](https://download.csdn.net/download/weixin_38442818/10214750),网上搜集,具体作者、来源不详\n",
|
||||
"4. **加工处理:**\n",
|
||||
" 1. 将原来的 2 份文档,整合成 1 份 csv 文件\n",
|
||||
" 2. 编码统一为 UTF-8\n",
|
||||
" 3. 去重"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"path = 'weibo_senti_100k_文件夹_所在_路径'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 1. weibo_senti_100k.csv"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 加载数据"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"评论数目(总体):119988\n",
|
||||
"评论数目(正向):59993\n",
|
||||
"评论数目(负向):59995\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pd_all = pd.read_csv(path + 'weibo_senti_100k.csv')\n",
|
||||
"\n",
|
||||
"print('评论数目(总体):%d' % pd_all.shape[0])\n",
|
||||
"print('评论数目(正向):%d' % pd_all[pd_all.label==1].shape[0])\n",
|
||||
"print('评论数目(负向):%d' % pd_all[pd_all.label==0].shape[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 字段说明\n",
|
||||
"\n",
|
||||
"| 字段 | 说明 |\n",
|
||||
"| ---- | ---- |\n",
|
||||
"| label | 1 表示正向评论,0 表示负向评论 |\n",
|
||||
"| review | 微博内容 |"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>label</th>\n",
|
||||
" <th>review</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>62050</th>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>太过分了@Rexzhenghao //@Janie_Zhang:招行最近负面新闻越来越多呀...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>68263</th>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>希望你?得好?我本"?肥血?史"[晕][哈哈]@Pete三姑父</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>81472</th>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>有点想参加????[偷?]想安排下时间再决定[抓狂]//@黑晶晶crystal: @细腿大羽...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>42021</th>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>[给力]感谢所有支持雯婕的芝麻![爱你]</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>7777</th>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>2013最后一天,在新加坡开心度过,向所有的朋友们问声:新年快乐!2014年,我们会更好[调...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>100399</th>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>大中午出门办事找错路,曝晒中。要多杯具有多杯具。[泪][泪][汗]</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>82398</th>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>马航还会否认吗?到底在隐瞒啥呢?[抓狂]//@头条新闻: 转发微博</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>106423</th>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>克罗地亚球迷很爱放烟火!球又没进,就硝烟四起。[晕]</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>24798</th>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>[抱抱]福芦 TangRoulou 吉祥书 8.8折优惠 >>> http://t.cn/z...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>6598</th>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>回复@钱旭明QXM:[嘻嘻][嘻嘻] //@钱旭明QXM:杨大哥[good][good][g...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>53920</th>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>人家这脸长的!!!!!![哈哈]</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>15587</th>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>这个价不算高,和一天内训相比相差无几。。[哈哈]//@博通传媒v: 6个月!一个月工资1万,...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>101237</th>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>终于收工啦,脚丫子快冻掉了[泪][泪][泪]</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>82449</th>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>我决定从今天开始我想吃什么就去吃什么,一个人吃也无所谓,重点是不要因为别人的意见委屈了自己[...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>32537</th>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>飘雪的北京 需要双份早餐.......//@美食天下: [哈哈]//@王淼Margay: 屁...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>10630</th>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>[耶],这个太赞了,生活大爆炸第六季马上要出啦[鼓掌] //@-郑瑜-:这个不错 //@经典...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>85130</th>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>刚追完#倾世皇妃#,#千山暮雪#又紧随其后,网速和更新速度都太不给力,尽管我看过原著,还是焦...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>105956</th>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>晚上看金二胖?察前?,推出的火炮基座?糟了,可以PK了[泪] //@艾米粒er: //@wi...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>72391</th>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>必须把中国足球的伟大,用我的职业演说出来 //@袁腾飞:[泪]</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>10761</th>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>[鼓掌] //@宁波香格里拉大酒店: 小编来答疑,周五晚惊艳全场的树根蛋糕到底有多长?蛋糕全...</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" label review\n",
|
||||
"62050 0 太过分了@Rexzhenghao //@Janie_Zhang:招行最近负面新闻越来越多呀...\n",
|
||||
"68263 0 希望你?得好?我本"?肥血?史"[晕][哈哈]@Pete三姑父\n",
|
||||
"81472 0 有点想参加????[偷?]想安排下时间再决定[抓狂]//@黑晶晶crystal: @细腿大羽...\n",
|
||||
"42021 1 [给力]感谢所有支持雯婕的芝麻![爱你]\n",
|
||||
"7777 1 2013最后一天,在新加坡开心度过,向所有的朋友们问声:新年快乐!2014年,我们会更好[调...\n",
|
||||
"100399 0 大中午出门办事找错路,曝晒中。要多杯具有多杯具。[泪][泪][汗]\n",
|
||||
"82398 0 马航还会否认吗?到底在隐瞒啥呢?[抓狂]//@头条新闻: 转发微博\n",
|
||||
"106423 0 克罗地亚球迷很爱放烟火!球又没进,就硝烟四起。[晕]\n",
|
||||
"24798 1 [抱抱]福芦 TangRoulou 吉祥书 8.8折优惠 >>> http://t.cn/z...\n",
|
||||
"6598 1 回复@钱旭明QXM:[嘻嘻][嘻嘻] //@钱旭明QXM:杨大哥[good][good][g...\n",
|
||||
"53920 1 人家这脸长的!!!!!![哈哈]\n",
|
||||
"15587 1 这个价不算高,和一天内训相比相差无几。。[哈哈]//@博通传媒v: 6个月!一个月工资1万,...\n",
|
||||
"101237 0 终于收工啦,脚丫子快冻掉了[泪][泪][泪]\n",
|
||||
"82449 0 我决定从今天开始我想吃什么就去吃什么,一个人吃也无所谓,重点是不要因为别人的意见委屈了自己[...\n",
|
||||
"32537 1 飘雪的北京 需要双份早餐.......//@美食天下: [哈哈]//@王淼Margay: 屁...\n",
|
||||
"10630 1 [耶],这个太赞了,生活大爆炸第六季马上要出啦[鼓掌] //@-郑瑜-:这个不错 //@经典...\n",
|
||||
"85130 0 刚追完#倾世皇妃#,#千山暮雪#又紧随其后,网速和更新速度都太不给力,尽管我看过原著,还是焦...\n",
|
||||
"105956 0 晚上看金二胖?察前?,推出的火炮基座?糟了,可以PK了[泪] //@艾米粒er: //@wi...\n",
|
||||
"72391 0 必须把中国足球的伟大,用我的职业演说出来 //@袁腾飞:[泪]\n",
|
||||
"10761 1 [鼓掌] //@宁波香格里拉大酒店: 小编来答疑,周五晚惊艳全场的树根蛋糕到底有多长?蛋糕全..."
|
||||
]
|
||||
},
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pd_all.sample(20)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.4"
|
||||
},
|
||||
"widgets": {
|
||||
"state": {},
|
||||
"version": "1.1.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,53 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Qwen3模型配置文件
|
||||
定义不同规模的模型参数和配置
|
||||
"""
|
||||
|
||||
# Qwen3模型配置
|
||||
# Per-size Qwen3 settings, keyed by model size. Each entry holds the
# Hugging Face repo ids of the base and embedding models, the embedding
# width, the maximum sequence length, and the recommended batch size,
# learning rate and LoRA rank/alpha.
QWEN3_MODELS = {
    "0.6B": {
        "base_model": "Qwen/Qwen3-0.6B",
        "embedding_model": "Qwen/Qwen3-Embedding-0.6B",
        "embedding_dim": 1024,
        "max_length": 32768,
        "recommended_batch_size": 32,
        "recommended_lr": 1e-3,
        "lora_r": 16,
        "lora_alpha": 32
    },
    "4B": {
        "base_model": "Qwen/Qwen3-4B",
        "embedding_model": "Qwen/Qwen3-Embedding-4B",
        "embedding_dim": 2560,
        "max_length": 32768,
        "recommended_batch_size": 16,
        "recommended_lr": 5e-4,
        "lora_r": 32,
        "lora_alpha": 64
    },
    "8B": {
        "base_model": "Qwen/Qwen3-8B",
        "embedding_model": "Qwen/Qwen3-Embedding-8B",
        "embedding_dim": 4096,
        "max_length": 32768,
        "recommended_batch_size": 8,
        "recommended_lr": 2e-4,
        "lora_r": 64,
        "lora_alpha": 128
    }
}

# Backward-compatible alias: the prediction script imports this table as
# ``MODEL_CONFIGS`` (presumably from this module -- confirm the module name),
# which would otherwise fail with an ImportError.
MODEL_CONFIGS = QWEN3_MODELS
|
||||
|
||||
# 模型文件路径配置
|
||||
# Where trained artefacts are stored, keyed first by method ("embedding" or
# "lora") and then by model size. The file names embed the lower-cased size.
_MODEL_SIZES = ("0.6B", "4B", "8B")

MODEL_PATHS = {
    "embedding": {
        size: f"./models/qwen3_embedding_{size.lower()}_sentiment.pth"
        for size in _MODEL_SIZES
    },
    "lora": {
        size: f"./models/qwen3_lora_{size.lower()}_final"
        for size in _MODEL_SIZES
    },
}
|
||||
@@ -0,0 +1,373 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Qwen3微博情感分析统一预测接口
|
||||
支持0.6B、4B、8B三种规格的Embedding和LoRA模型
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import torch
|
||||
from typing import List, Dict, Tuple, Any
|
||||
|
||||
# 添加当前目录到路径
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from models_config import MODEL_CONFIGS, MODEL_PATHS
|
||||
from qwen3_embedding_universal import Qwen3EmbeddingUniversal
|
||||
from qwen3_lora_universal import Qwen3LoRAUniversal
|
||||
|
||||
|
||||
class Qwen3UniversalPredictor:
    """Unified predictor over the Qwen3 embedding and LoRA sentiment models.

    Loaded models are kept in ``self.models`` as
    ``{model_key: {'model': obj, 'display_name': str}}``.
    """

    def __init__(self):
        self.models = {}  # loaded models: {model_key: {model: obj, display_name: str}}

    def _get_model_key(self, model_type: str, model_size: str) -> str:
        """Build the dictionary key for a (type, size) combination."""
        return f"{model_type}_{model_size}"

    def load_model(self, model_type: str, model_size: str, model_dir: str = None) -> bool:
        """Load one model and register it in ``self.models``.

        Args:
            model_type: ``'embedding'`` or ``'lora'``.
            model_size: ``'0.6B'``, ``'4B'`` or ``'8B'``.
            model_dir: Optional directory to look in. When given, the file
                name from ``MODEL_PATHS`` is resolved inside this directory;
                when ``None`` the ``MODEL_PATHS`` entry is used unchanged.

        Returns:
            ``True`` when the model was loaded, ``False`` when the file is
            missing or loading failed (a message is printed in both cases).

        Raises:
            ValueError: If ``model_type`` or ``model_size`` is unsupported.
        """
        if model_type not in ['embedding', 'lora']:
            raise ValueError(f"不支持的模型类型: {model_type}")
        if model_size not in ['0.6B', '4B', '8B']:
            raise ValueError(f"不支持的模型大小: {model_size}")

        model_path = MODEL_PATHS[model_type][model_size]
        if model_dir is not None:
            model_path = os.path.join(model_dir, os.path.basename(model_path))
        if not os.path.exists(model_path):
            print(f"模型文件不存在: {model_path}")
            return False

        model_key = self._get_model_key(model_type, model_size)
        print(f"加载 {model_type.upper()}-{model_size} 模型...")

        try:
            if model_type == 'embedding':
                model = Qwen3EmbeddingUniversal(model_size)
            else:  # lora
                model = Qwen3LoRAUniversal(model_size)
            model.load_model(model_path)
        except Exception as e:
            # Best effort: report the failure and let the caller carry on.
            print(f"加载 {model_type.upper()}-{model_size} 模型失败: {e}")
            return False

        self.models[model_key] = {
            'model': model,
            'display_name': f"Qwen3-{model_type.title()}-{model_size}"
        }
        print(f"{model_type.upper()}-{model_size} 模型加载成功")
        return True

    def load_all_models(self, model_dir: str = './models') -> None:
        """Try to load every (type, size) combination found in ``model_dir``."""
        print("开始加载所有可用的Qwen3模型...")

        loaded_count = 0
        for model_type in ['embedding', 'lora']:
            for model_size in ['0.6B', '4B', '8B']:
                try:
                    # Only count models that were actually loaded.
                    if self.load_model(model_type, model_size, model_dir):
                        loaded_count += 1
                except Exception as e:
                    print(f"跳过 {model_type}-{model_size}: {e}")

        print(f"\n已加载 {loaded_count} 个模型")
        self._print_loaded_models()

    def load_specific_models(self, model_configs: List[Tuple[str, str]]) -> None:
        """Load the listed model configurations.

        Args:
            model_configs: A list of ``(model_type, model_size)`` pairs.
        """
        print("加载指定的Qwen3模型...")

        for model_type, model_size in model_configs:
            try:
                self.load_model(model_type, model_size)
            except Exception as e:
                print(f"跳过 {model_type}-{model_size}: {e}")

        print(f"\n已加载 {len(self.models)} 个模型")
        self._print_loaded_models()

    def _print_loaded_models(self):
        """Print the display names of the loaded models."""
        if self.models:
            print("已加载模型:")
            for model_info in self.models.values():
                print(f" - {model_info['display_name']}")
        else:
            print("没有成功加载任何模型")

    def predict_single(self, text: str, model_key: str = None) -> Dict[str, Tuple[int, float]]:
        """Predict the sentiment of one text.

        Args:
            text: The text to classify.
            model_key: Key of the model to use. ``None`` (or a key that is
                not loaded) means every loaded model is used.

        Returns:
            ``{display_name: (prediction, confidence), ...}``. A model whose
            prediction raises is reported as ``(0, 0.0)``.
        """
        if model_key and model_key in self.models:
            selected = [self.models[model_key]]
        else:
            selected = list(self.models.values())

        results = {}
        for model_info in selected:
            try:
                prediction, confidence = model_info['model'].predict_single(text)
                results[model_info['display_name']] = (prediction, confidence)
            except Exception as e:
                print(f"模型 {model_info['display_name']} 预测失败: {e}")
                results[model_info['display_name']] = (0, 0.0)

        return results

    def predict_batch(self, texts: List[str]) -> Dict[str, List[int]]:
        """Predict labels for many texts with every loaded model.

        A model whose prediction raises is reported as all zeros.
        """
        results = {}

        for model_info in self.models.values():
            try:
                predictions = model_info['model'].predict(texts)
                results[model_info['display_name']] = predictions
            except Exception as e:
                print(f"模型 {model_info['display_name']} 预测失败: {e}")
                results[model_info['display_name']] = [0] * len(texts)

        return results

    def ensemble_predict(self, text: str) -> Tuple[int, float]:
        """Average the positive-class probability over all loaded models.

        Returns:
            ``(prediction, confidence)``; ``(0, 0.5)`` when no model gave a
            usable (positive-confidence) prediction.

        Raises:
            ValueError: If fewer than two models are loaded.
        """
        if len(self.models) < 2:
            raise ValueError("集成预测需要至少2个模型")

        results = self.predict_single(text)

        # Plain average; weights could later reflect per-model quality.
        total_weight = 0
        weighted_prob = 0

        for model_name, (pred, conf) in results.items():
            if conf > 0:  # only count valid predictions
                # Convert (label, confidence) into P(positive).
                prob = conf if pred == 1 else 1 - conf
                weighted_prob += prob
                total_weight += 1

        if total_weight == 0:
            return 0, 0.5

        final_prob = weighted_prob / total_weight
        final_pred = int(final_prob > 0.5)
        final_conf = final_prob if final_pred == 1 else 1 - final_prob

        return final_pred, final_conf

    @staticmethod
    def _prompt_choice(prompt: str, valid: List[str], error_message: str) -> str:
        """Keep asking on stdin until the answer is one of ``valid``."""
        while True:
            choice = input(prompt).strip()
            if choice in valid:
                return choice
            print(error_message)

    def _select_and_load_model(self):
        """Ask the user for a method and a size, then load that model."""
        print("Qwen3微博情感分析预测系统")
        print("="*40)
        print("请选择要使用的模型:")
        print("\n方法选择:")
        print(" 1. Embedding + 分类头 (推理快速,显存占用少)")
        print(" 2. LoRA微调 (效果更好,显存占用较多)")

        method_choice = self._prompt_choice(
            "\n请选择方法 (1/2): ", ['1', '2'], "无效选择,请输入 1 或 2")

        method_type = "embedding" if method_choice == '1' else "lora"
        method_name = "Embedding + 分类头" if method_choice == '1' else "LoRA微调"

        print(f"\n已选择: {method_name}")
        print("\n模型大小选择:")
        print(" 1. 0.6B - 轻量级,推理快速")
        print(" 2. 4B - 中等规模,性能均衡")
        print(" 3. 8B - 大规模,性能最佳")

        size_choice = self._prompt_choice(
            "\n请选择模型大小 (1/2/3): ", ['1', '2', '3'], "无效选择,请输入 1、2 或 3")

        size_map = {'1': '0.6B', '2': '4B', '3': '8B'}
        model_size = size_map[size_choice]

        print(f"已选择: Qwen3-{method_name}-{model_size}")
        print("正在加载模型...")

        try:
            loaded = self.load_model(method_type, model_size)
        except Exception as e:
            print(f"模型加载失败: {e}")
            loaded = False

        # load_model reports failure through its return value, so only claim
        # success when the model is really available.
        if loaded:
            print(f"模型加载成功!")
        else:
            print("请检查模型文件是否存在,或先进行训练")

    def interactive_predict(self):
        """Run the interactive read-predict loop on stdin/stdout."""
        if len(self.models) == 0:
            # Let the user pick which model to load.
            self._select_and_load_model()
            if len(self.models) == 0:
                print("没有加载任何模型,退出预测")
                return

        print("\n" + "="*60)
        print("Qwen3微博情感分析预测系统")
        print("="*60)
        self._print_loaded_models()
        print("\n命令提示:")
        print(" 输入 'q' 退出程序")
        print(" 输入 'switch' 切换模型")
        print(" 输入 'models' 查看已加载模型")
        print(" 输入 'compare' 比较所有模型性能")
        print("-"*60)

        while True:
            try:
                text = input("\n请输入要分析的微博内容: ").strip()
                command = text.lower()

                if command == 'q':
                    print("感谢使用,再见!")
                    break

                if command == 'models':
                    self._print_loaded_models()
                    continue

                if command == 'switch':
                    print("切换模型...")
                    self.models.clear()  # drop the current model first
                    self._select_and_load_model()
                    if len(self.models) > 0:
                        print("模型切换成功!")
                        for model_info in self.models.values():
                            print(f" 当前模型: {model_info['display_name']}")
                    continue

                if command == 'compare':
                    test_text = input("请输入要比较的文本: ")
                    self._compare_models(test_text)
                    continue

                if not text:
                    print("请输入有效内容")
                    continue

                if not self.models:
                    # A failed 'switch' can leave nothing loaded.
                    self._print_loaded_models()
                    continue

                results = self.predict_single(text)

                print(f"\n原文: {text}")
                print("预测结果:")

                # Show results sorted by model display name (no ensembling).
                for model_name, (pred, conf) in sorted(results.items()):
                    sentiment = "正面" if pred == 1 else "负面"
                    print(f" {model_name:20}: {sentiment} (置信度: {conf:.4f})")

            except (KeyboardInterrupt, EOFError):
                # EOF on stdin must end the loop too; otherwise input() would
                # raise again on every iteration and the loop never exits.
                print("\n\n程序被中断,再见!")
                break
            except Exception as e:
                print(f"预测过程中出现错误: {e}")

    def _compare_models(self, text: str):
        """Print every model's prediction for ``text``, grouped by method."""
        print(f"\n模型性能比较 - 文本: {text}")
        print("-" * 60)

        results = self.predict_single(text)

        embedding_models = []
        lora_models = []

        for model_name, (pred, conf) in results.items():
            sentiment = "正面" if pred == 1 else "负面"
            # Display names are built with str.title(), hence "Lora".
            if "Embedding" in model_name:
                embedding_models.append((model_name, sentiment, conf))
            elif "Lora" in model_name:
                lora_models.append((model_name, sentiment, conf))

        if embedding_models:
            print("Embedding + 分类头方法:")
            for name, sentiment, conf in embedding_models:
                print(f" {name}: {sentiment} ({conf:.4f})")

        if lora_models:
            print("LoRA微调方法:")
            for name, sentiment, conf in lora_models:
                print(f" {name}: {sentiment} ({conf:.4f})")
|
||||
|
||||
|
||||
def _build_arg_parser():
    """Create the command-line parser for the prediction entry point."""
    parser = argparse.ArgumentParser(description='Qwen3微博情感分析统一预测接口')
    parser.add_argument('--model_dir', type=str, default='./models',
                        help='模型文件目录')
    parser.add_argument('--model_type', type=str, choices=['embedding', 'lora'],
                        help='指定模型类型')
    parser.add_argument('--model_size', type=str, choices=['0.6B', '4B', '8B'],
                        help='指定模型大小')
    parser.add_argument('--text', type=str,
                        help='直接预测指定文本')
    parser.add_argument('--interactive', action='store_true', default=True,
                        help='交互式预测模式(默认)')
    parser.add_argument('--ensemble', action='store_true',
                        help='使用集成预测')
    parser.add_argument('--load_all', action='store_true',
                        help='加载所有可用模型')
    return parser


def main(argv=None):
    """Command-line entry point.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``
            (argparse's default), so existing ``main()`` calls are unchanged.

    With ``--text`` the text is classified once and the result printed;
    otherwise the interactive mode is started.
    """
    args = _build_arg_parser().parse_args(argv)

    predictor = Qwen3UniversalPredictor()

    # Load models as requested on the command line.
    if args.load_all:
        predictor.load_all_models(args.model_dir)
    elif args.model_type and args.model_size:
        predictor.load_model(args.model_type, args.model_size)
    elif args.model_type or args.model_size:
        # Only one half of the pair was given; say so instead of silently
        # ignoring it.
        print("提示: --model_type 和 --model_size 需要同时指定,已忽略")
    # With no model specified, interactive mode lets the user choose one.

    if not args.text:
        predictor.interactive_predict()
        return

    if len(predictor.models) == 0:
        # Nothing to predict with; report it rather than print an empty result.
        print("没有已加载的模型,无法预测,请通过 --model_type/--model_size 或 --load_all 指定模型")
        return

    print_ensemble = args.ensemble and len(predictor.models) > 1
    if print_ensemble:
        pred, conf = predictor.ensemble_predict(args.text)
        sentiment = "正面" if pred == 1 else "负面"
        print(f"文本: {args.text}")
        print(f"集成预测: {sentiment} (置信度: {conf:.4f})")
    else:
        results = predictor.predict_single(args.text)
        print(f"文本: {args.text}")
        for model_name, (pred, conf) in results.items():
            sentiment = "正面" if pred == 1 else "负面"
            print(f"{model_name}: {sentiment} (置信度: {conf:.4f})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,409 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Qwen3-Embedding通用训练脚本
|
||||
支持0.6B、4B、8B三种规模的模型
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
from typing import List, Tuple
|
||||
import warnings
|
||||
from tqdm import tqdm
|
||||
|
||||
from base_model import BaseQwenModel
|
||||
from models_config import QWEN3_MODELS, MODEL_PATHS
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
class SentimentDataset(Dataset):
    """Torch dataset that tokenises ``(text, label)`` pairs on access."""

    def __init__(self, data: List[Tuple[str, int]], tokenizer, max_length=512):
        texts, labels = [], []
        for sample in data:
            texts.append(sample[0])
            labels.append(sample[1])

        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Every sample is padded/truncated to exactly ``max_length`` tokens.
        encoded = self.tokenizer(
            str(self.texts[idx]),
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # The tokenizer is called on one text; flatten its tensors to 1-D.
        sample = {}
        for field in ('input_ids', 'attention_mask'):
            sample[field] = encoded[field].flatten()
        # Float label, as used with the BCE loss in training.
        sample['label'] = torch.tensor(self.labels[idx], dtype=torch.float)
        return sample
|
||||
|
||||
|
||||
class SentimentClassifier(nn.Module):
    """Binary sentiment head on top of a frozen embedding model.

    Args:
        embedding_model: Module called as
            ``embedding_model(input_ids=..., attention_mask=...)`` whose
            result exposes ``last_hidden_state``.
        embedding_dim: Width of the hidden states fed to the head.
        hidden_dim: Width of the head's hidden layer.
        pooling: ``"last"`` (default) pools the hidden state of the last
            non-padding token; ``"first"`` pools position 0, which is what
            this class did before and is kept only for heads trained that way.
    """

    def __init__(self, embedding_model, embedding_dim, hidden_dim=256, pooling="last"):
        super(SentimentClassifier, self).__init__()
        if pooling not in ("last", "first"):
            raise ValueError(f"unsupported pooling: {pooling}")
        self.embedding_model = embedding_model
        self.pooling = pooling

        # Freeze the embedding model; only the head is trained.
        for param in self.embedding_model.parameters():
            param.requires_grad = False

        # Classification head ending in a sigmoid, i.e. it outputs P(positive).
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, input_ids, attention_mask):
        """Return one probability per sample, shape ``(batch,)``."""
        with torch.no_grad():
            outputs = self.embedding_model(input_ids=input_ids, attention_mask=attention_mask)
            hidden = outputs.last_hidden_state

            if self.pooling == "first":
                embeddings = hidden[:, 0, :]
            else:
                # The Qwen3 embedding models are causal decoders, so position 0
                # cannot attend to the rest of the text; only the last real
                # token has seen the whole input. Find it from the mask: the
                # largest position whose mask is 1 (works for left and right
                # padding alike).
                positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
                last_index = (attention_mask * positions).argmax(dim=1).to(hidden.device)
                rows = torch.arange(hidden.shape[0], device=hidden.device)
                embeddings = hidden[rows, last_index]

        probs = self.classifier(embeddings)
        # squeeze(-1) keeps the batch axis even for a batch of one, so the
        # output always matches the (batch,) label tensor.
        return probs.squeeze(-1)
|
||||
|
||||
|
||||
class Qwen3EmbeddingUniversal(BaseQwenModel):
|
||||
"""通用Qwen3-Embedding模型"""
|
||||
|
||||
def __init__(self, model_size: str = "0.6B"):
    """Set up an untrained embedding-based model for the given size.

    Args:
        model_size: A key of ``QWEN3_MODELS``.

    Raises:
        ValueError: If ``model_size`` is not a key of ``QWEN3_MODELS``.
    """
    if model_size not in QWEN3_MODELS:
        raise ValueError(f"不支持的模型大小: {model_size}")

    super().__init__(f"Qwen3-Embedding-{model_size}")

    size_config = QWEN3_MODELS[model_size]
    self.model_size = model_size
    self.config = size_config
    self.model_name_hf = size_config["embedding_model"]
    self.embedding_dim = size_config["embedding_dim"]

    # Filled in later by _load_embedding_model() and train().
    self.tokenizer = None
    self.embedding_model = None
    self.classifier_model = None

    # Use the GPU when one is available.
    if torch.cuda.is_available():
        device_name = 'cuda'
    else:
        device_name = 'cpu'
    self.device = torch.device(device_name)
|
||||
|
||||
def _load_embedding_model(self):
    """Load the tokenizer and embedding model onto ``self.device``.

    Three attempts are made in order: load by Hugging Face name (saving a
    copy under ``./models`` if none exists yet), load from that local copy,
    and finally load by name again with the local directory as cache and
    save the result there.
    """
    print(f"加载{self.model_size}模型: {self.model_name_hf}")

    hub_name = self.model_name_hf
    local_dir = f"./models/qwen3-embedding-{self.model_size.lower()}"

    try:
        self.tokenizer = AutoTokenizer.from_pretrained(hub_name)
        self.embedding_model = AutoModel.from_pretrained(hub_name).to(self.device)
        print(f"{self.model_size}模型加载完成")

        # Keep a local copy right away, unless the directory already exists.
        already_saved = os.path.exists(local_dir)
        if not already_saved:
            print(f"保存模型到本地: {local_dir}")
            os.makedirs(local_dir, exist_ok=True)
            self.tokenizer.save_pretrained(local_dir)
            self.embedding_model.save_pretrained(local_dir)
            print(f"模型已保存到: {local_dir}")

    except Exception as e:
        print(f"从Hugging Face加载失败: {e}")

        # Second attempt: the local copy.
        try:
            if not os.path.exists(local_dir):
                raise FileNotFoundError("本地缓存也不存在")

            self.tokenizer = AutoTokenizer.from_pretrained(local_dir)
            self.embedding_model = AutoModel.from_pretrained(local_dir).to(self.device)
            print(f"从本地缓存加载{self.model_size}模型成功")

        except Exception as e2:
            print(f"本地加载也失败: {e2}")
            print(f"正在下载{self.model_size}模型...")

            # Last attempt: download with the local directory as cache.
            os.makedirs(local_dir, exist_ok=True)
            self.tokenizer = AutoTokenizer.from_pretrained(hub_name, cache_dir=local_dir)
            self.embedding_model = AutoModel.from_pretrained(hub_name, cache_dir=local_dir).to(self.device)

            # Save to the local directory.
            self.tokenizer.save_pretrained(local_dir)
            self.embedding_model.save_pretrained(local_dir)
            print(f"{self.model_size}模型下载并保存到: {local_dir}")
|
||||
|
||||
def train(self, train_data: List[Tuple[str, int]], **kwargs) -> None:
    """Train the classification head on top of the Qwen3 embedding backbone.

    Only ``self.classifier_model.classifier`` is optimized. The backbone is
    frozen explicitly so that ``backward()`` does not build and accumulate
    gradients for weights the optimizer never updates.

    Args:
        train_data: ``(text, label)`` pairs; labels are expected to be 0/1.
        **kwargs: Optional overrides:
            ``batch_size`` (default ``config['recommended_batch_size']``),
            ``learning_rate`` (default ``config['recommended_lr']``),
            ``num_epochs`` (default 5), ``max_length`` (default 512).

    Raises:
        ValueError: If ``train_data`` is empty.
    """
    if not train_data:
        # Fail before downloading a multi-GB model for nothing.
        raise ValueError("训练数据为空, 无法训练模型")

    print(f"开始训练 Qwen3-Embedding-{self.model_size} 模型...")

    self._load_embedding_model()

    # Hyper-parameters: caller-supplied values win over the configured ones.
    batch_size = kwargs.get('batch_size', self.config['recommended_batch_size'])
    learning_rate = kwargs.get('learning_rate', self.config['recommended_lr'])
    num_epochs = kwargs.get('num_epochs', 5)
    max_length = kwargs.get('max_length', 512)

    print(f"超参数: batch_size={batch_size}, lr={learning_rate}, epochs={num_epochs}")
    print(f"嵌入维度: {self.embedding_dim}")

    train_dataset = SentimentDataset(train_data, self.tokenizer, max_length)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # The optimizer below only sees the head, so the backbone is a fixed
    # feature extractor: switch its gradients off to save memory and time.
    for backbone_param in self.embedding_model.parameters():
        backbone_param.requires_grad_(False)

    self.classifier_model = SentimentClassifier(
        self.embedding_model,
        self.embedding_dim
    ).to(self.device)

    # NOTE(review): BCELoss assumes the classifier emits probabilities with
    # the same shape/dtype as the labels from SentimentDataset — confirm.
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(self.classifier_model.classifier.parameters(), lr=learning_rate)

    self.classifier_model.train()
    # Keep the frozen backbone in inference mode so its features stay stable.
    self.embedding_model.eval()

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for batch in progress_bar:
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['label'].to(self.device)

            outputs = self.classifier_model(input_ids, attention_mask)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            progress_bar.set_postfix({'loss': total_loss / num_batches})

        avg_loss = total_loss / num_batches
        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

    self.model = self.classifier_model
    self.is_trained = True
    print(f"Qwen3-Embedding-{self.model_size} 模型训练完成!")
|
||||
|
||||
def predict(self, texts: List[str]) -> List[int]:
    """Predict a 0/1 sentiment label for every text.

    Texts are tokenized and scored in batches of 32 with a 0.5 decision
    threshold on the classifier output.

    Args:
        texts: Texts to classify; an empty list yields an empty result.

    Returns:
        A flat list with one integer label (1 positive, 0 negative) per text.

    Raises:
        ValueError: If the model has not been trained or loaded.
    """
    if not self.is_trained:
        raise ValueError(f"模型 {self.model_name} 尚未训练")

    predictions: List[int] = []
    batch_size = 32

    self.classifier_model.eval()
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch_texts = list(texts[start:start + batch_size])

            encodings = self.tokenizer(
                batch_texts,
                max_length=512,
                padding=True,
                truncation=True,
                return_tensors='pt'
            )
            input_ids = encodings['input_ids'].to(self.device)
            attention_mask = encodings['attention_mask'].to(self.device)

            outputs = self.classifier_model(input_ids, attention_mask)

            # Flatten first: a trailing batch of one can come back as a 0-d
            # tensor (if the head squeezes its output) and an (N, 1) output
            # would otherwise produce nested lists instead of flat labels.
            labels = (outputs.reshape(-1) > 0.5).long().cpu().tolist()
            predictions.extend(labels)

    return predictions
|
||||
|
||||
def predict_single(self, text: str) -> Tuple[int, float]:
    """Classify one text and report how confident the classifier is.

    Args:
        text: The text to classify.

    Returns:
        ``(label, confidence)`` where ``label`` is 1 when the classifier
        output exceeds 0.5 and 0 otherwise, and ``confidence`` is the
        output for label 1 or its complement for label 0.

    Raises:
        ValueError: If the model has not been trained or loaded.
    """
    if not self.is_trained:
        raise ValueError(f"模型 {self.model_name} 尚未训练")

    self.classifier_model.eval()
    with torch.no_grad():
        encoded = self.tokenizer(
            text,
            max_length=512,
            padding=True,
            truncation=True,
            return_tensors='pt'
        )
        score = self.classifier_model(
            encoded['input_ids'].to(self.device),
            encoded['attention_mask'].to(self.device),
        ).item()

    if score > 0.5:
        return 1, score
    return 0, 1 - score
|
||||
|
||||
def save_model(self, model_path: str = None) -> None:
    """Save the trained classification head and its metadata with ``torch.save``.

    Only the head's ``state_dict`` is stored; the embedding backbone is
    re-loaded separately when the checkpoint is restored.

    Args:
        model_path: Destination file. Defaults to
            ``MODEL_PATHS["embedding"][self.model_size]``.

    Raises:
        ValueError: If the model has not been trained or loaded.
    """
    if not self.is_trained:
        raise ValueError(f"模型 {self.model_name} 尚未训练")

    if model_path is None:
        model_path = MODEL_PATHS["embedding"][self.model_size]

    # A bare filename has an empty dirname, and os.makedirs("") raises.
    parent_dir = os.path.dirname(model_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    model_data = {
        'classifier_state_dict': self.classifier_model.classifier.state_dict(),
        'model_size': self.model_size,
        'model_name_hf': self.model_name_hf,
        'embedding_dim': self.embedding_dim,
        'device': str(self.device)
    }

    torch.save(model_data, model_path)
    print(f"模型已保存到: {model_path}")
|
||||
|
||||
def load_model(self, model_path: str) -> None:
    """Restore a classification head saved by ``save_model``.

    The checkpoint is read first and its ``model_size`` validated, then the
    embedding backbone is loaded and the head weights are applied on top.

    Args:
        model_path: Path of the checkpoint file.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        ValueError: If the checkpoint was trained for a different model size.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"模型文件不存在: {model_path}")

    # The checkpoint holds only tensors and primitives, so restrict the
    # unpickler: a checkpoint from an untrusted source cannot run code.
    model_data = torch.load(model_path, map_location=self.device, weights_only=True)

    if model_data['model_size'] != self.model_size:
        raise ValueError(f"模型大小不匹配: 期望{self.model_size}, 实际{model_data['model_size']}")

    self._load_embedding_model()

    # Rebuild the head with the dimension recorded at save time.
    self.classifier_model = SentimentClassifier(
        self.embedding_model,
        model_data['embedding_dim']
    ).to(self.device)
    self.classifier_model.classifier.load_state_dict(model_data['classifier_state_dict'])

    self.model = self.classifier_model
    self.is_trained = True
    print(f"已加载Qwen3-Embedding-{self.model_size}模型: {model_path}")
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: train or evaluate a Qwen3-Embedding classifier.

    When ``--model_size`` is omitted the size is asked for interactively.
    With ``--eval_only`` a saved head is loaded and evaluated; otherwise the
    model is trained, evaluated and saved. A few demo predictions follow.
    """
    parser = argparse.ArgumentParser(description='Qwen3-Embedding通用训练脚本')
    parser.add_argument('--model_size', type=str, choices=['0.6B', '4B', '8B'],
                        help='模型大小')
    parser.add_argument('--train_path', type=str, default='./dataset/train.txt',
                        help='训练数据路径')
    parser.add_argument('--test_path', type=str, default='./dataset/test.txt',
                        help='测试数据路径')
    parser.add_argument('--model_path', type=str, help='模型保存路径(可选)')
    parser.add_argument('--epochs', type=int, default=5, help='训练轮数')
    parser.add_argument('--batch_size', type=int, help='批大小(可选,使用推荐值)')
    parser.add_argument('--learning_rate', type=float, help='学习率(可选,使用推荐值)')
    parser.add_argument('--eval_only', action='store_true', help='仅评估模式')

    args = parser.parse_args()

    # No size on the command line: fall back to an interactive menu.
    if not args.model_size:
        menu_lines = (
            "Qwen3-Embedding模型训练",
            "="*40,
            "可用模型大小:",
            " 1. 0.6B - 轻量级,训练快速,显存需求约2GB",
            " 2. 4B - 中等规模,性能均衡,显存需求约8GB",
            " 3. 8B - 大规模,性能最佳,显存需求约16GB",
        )
        for menu_line in menu_lines:
            print(menu_line)

        size_by_choice = {'1': '0.6B', '2': '4B', '3': '8B'}
        selected_size = None
        while selected_size is None:
            answer = input("\n请选择模型大小 (1/2/3): ").strip()
            selected_size = size_by_choice.get(answer)
            if selected_size is None:
                print("无效选择,请输入 1、2 或 3")
        args.model_size = selected_size

        print(f"已选择: Qwen3-Embedding-{args.model_size}")
        print()

    os.makedirs('./models', exist_ok=True)

    model = Qwen3EmbeddingUniversal(args.model_size)

    # An explicit --model_path wins over the configured default location.
    model_path = args.model_path or MODEL_PATHS["embedding"][args.model_size]

    if args.eval_only:
        print(f"评估模式:加载Qwen3-Embedding-{args.model_size}模型")
        model.load_model(model_path)

        _, test_data = BaseQwenModel.load_data(args.train_path, args.test_path)
        model.evaluate(test_data)
    else:
        train_data, test_data = BaseQwenModel.load_data(args.train_path, args.test_path)

        # Forward only the overrides the user actually supplied.
        train_kwargs = {'num_epochs': args.epochs}
        overrides = {'batch_size': args.batch_size, 'learning_rate': args.learning_rate}
        train_kwargs.update({key: value for key, value in overrides.items() if value})

        model.train(train_data, **train_kwargs)
        model.evaluate(test_data)
        model.save_model(model_path)

    # NOTE(review): the original indentation was lost in extraction; the demo
    # is assumed to run after both the training and the eval-only branch —
    # confirm against the real file.
    print(f"\nQwen3-Embedding-{args.model_size} 示例预测:")
    demo_texts = [
        "今天天气真好,心情很棒",
        "这部电影太无聊了,浪费时间",
        "哈哈哈,太有趣了"
    ]

    for sample in demo_texts:
        label, confidence = model.predict_single(sample)
        print(f"文本: {sample}")
        print(f"预测: {'正面' if label == 1 else '负面'} (置信度: {confidence:.4f})")
        print()


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,444 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Qwen3-LoRA通用训练脚本
|
||||
支持0.6B、4B、8B三种规模的模型
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
TrainingArguments,
|
||||
Trainer,
|
||||
DataCollatorForLanguageModeling
|
||||
)
|
||||
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
|
||||
from datasets import Dataset
|
||||
from typing import List, Tuple
|
||||
import warnings
|
||||
from tqdm import tqdm
|
||||
|
||||
from base_model import BaseQwenModel
|
||||
from models_config import QWEN3_MODELS, MODEL_PATHS
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
class Qwen3LoRAUniversal(BaseQwenModel):
    """Qwen3 causal language model fine-tuned with LoRA for sentiment prompts.

    The model is trained on instruction-style text (prompt followed by the
    answer '正面' or '负面') and classifies by generating a short completion
    and looking for one of those two answers in it.
    """

    def __init__(self, model_size: str = "0.6B"):
        """Select the configuration for ``model_size`` without loading weights.

        Raises:
            ValueError: If ``model_size`` is not a key of ``QWEN3_MODELS``.
        """
        if model_size not in QWEN3_MODELS:
            raise ValueError(f"不支持的模型大小: {model_size}")

        super().__init__(f"Qwen3-{model_size}-LoRA")
        self.model_size = model_size
        self.config = QWEN3_MODELS[model_size]
        self.model_name_hf = self.config["base_model"]

        self.tokenizer = None
        self.base_model = None
        self.lora_model = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _build_instruction(self, text: str) -> str:
        """Return the prompt used for both training and inference."""
        return f"请分析以下微博文本的情感倾向,回答'正面'或'负面'。\n\n文本:{text}\n\n情感:"

    def _from_pretrained(self, source: str, **hub_kwargs) -> None:
        """Load tokenizer and base model from ``source`` and ensure a pad token."""
        on_gpu = torch.cuda.is_available()
        self.tokenizer = AutoTokenizer.from_pretrained(source, **hub_kwargs)
        self.base_model = AutoModelForCausalLM.from_pretrained(
            source,
            torch_dtype=torch.float16 if on_gpu else torch.float32,
            device_map="auto" if on_gpu else None,
            **hub_kwargs,
        )

        # Fall back to EOS when the tokenizer defines no padding token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def _load_base_model(self):
        """Load the Qwen3 base model, falling back to a local copy.

        Order of attempts: the Hugging Face Hub id; the local directory
        ``./models/qwen3-<size>``; the Hub again with ``cache_dir`` set to
        that directory. An error from the last attempt propagates.
        """
        print(f"加载{self.model_size}基础模型: {self.model_name_hf}")
        cache_dir = f"./models/qwen3-{self.model_size.lower()}"

        try:
            self._from_pretrained(self.model_name_hf)
            print(f"{self.model_size}基础模型加载完成")

            # Keep a local copy so a later offline run can fall back to it.
            if not os.path.exists(cache_dir):
                print(f"保存模型到本地: {cache_dir}")
                os.makedirs(cache_dir, exist_ok=True)
                self.tokenizer.save_pretrained(cache_dir)
                self.base_model.save_pretrained(cache_dir)
                print(f"模型已保存到: {cache_dir}")

        except Exception as e:
            print(f"从Hugging Face加载失败: {e}")

            try:
                if not os.path.exists(cache_dir):
                    raise FileNotFoundError("本地缓存也不存在")
                self._from_pretrained(cache_dir)
                print(f"从本地缓存加载{self.model_size}模型成功")

            except Exception as e2:
                print(f"本地加载也失败: {e2}")
                print(f"正在下载{self.model_size}模型...")

                os.makedirs(cache_dir, exist_ok=True)
                self._from_pretrained(self.model_name_hf, cache_dir=cache_dir)

                self.tokenizer.save_pretrained(cache_dir)
                self.base_model.save_pretrained(cache_dir)
                print(f"{self.model_size}模型下载并保存到: {cache_dir}")

    def _create_instruction_data(self, data: List[Tuple[str, int]]) -> Dataset:
        """Turn ``(text, label)`` pairs into instruction-formatted records.

        Each record carries the prompt, the expected answer and their
        concatenation terminated by the tokenizer's EOS token.
        """
        records = []
        for text, label in data:
            instruction = self._build_instruction(text)
            response = "正面" if label == 1 else "负面"
            records.append({
                "instruction": instruction,
                "response": response,
                "text": f"{instruction}{response}{self.tokenizer.eos_token}"
            })

        return Dataset.from_list(records)

    def _tokenize_function(self, examples):
        """Tokenize a batch of records and copy the input ids as labels."""
        tokenized = self.tokenizer(
            examples["text"],
            truncation=True,
            padding=False,
            max_length=512,
            return_tensors=None
        )

        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized

    def _setup_lora(self, **kwargs):
        """Wrap the base model with LoRA adapters on the attention projections.

        Args:
            **kwargs: Optional ``lora_r``, ``lora_alpha`` (defaults from the
                model config) and ``lora_dropout`` (default 0.1).

        Returns:
            The ``LoraConfig`` that was applied.
        """
        lora_r = kwargs.get('lora_r', self.config['lora_r'])
        lora_alpha = kwargs.get('lora_alpha', self.config['lora_alpha'])

        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=kwargs.get('lora_dropout', 0.1),
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        )

        self.lora_model = get_peft_model(self.base_model, lora_config)

        # Count trainable weights explicitly; reporting the total parameter
        # count here would make the ratio come out at roughly 100%.
        trainable_params = 0
        total_params = 0
        for param in self.lora_model.parameters():
            total_params += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
        ratio = trainable_params / total_params * 100 if total_params else 0.0

        print(f"LoRA配置完成 (r={lora_r}, alpha={lora_alpha})")
        print(f"可训练参数: {trainable_params:,}")
        print(f"参数比例: {ratio:.2f}%")

        return lora_config

    def train(self, train_data: List[Tuple[str, int]], **kwargs) -> None:
        """Fine-tune the base model with LoRA on instruction-formatted data.

        Args:
            train_data: ``(text, label)`` pairs with 0/1 labels.
            **kwargs: Optional ``num_epochs`` (default 3), ``batch_size``
                (default half the recommended size, at least 1),
                ``learning_rate`` (default half the recommended rate),
                ``output_dir`` and the LoRA options of ``_setup_lora``.

        Raises:
            ValueError: If ``train_data`` is empty.
        """
        if not train_data:
            # Fail before loading a multi-GB model for nothing.
            raise ValueError("训练数据为空, 无法训练模型")

        print(f"开始训练 Qwen3-{self.model_size}-LoRA 模型...")

        self._load_base_model()
        self._setup_lora(**kwargs)

        num_epochs = kwargs.get('num_epochs', 3)
        # Halved relative to the embedding setup; never allow a batch size of 0.
        batch_size = kwargs.get('batch_size', max(1, self.config['recommended_batch_size'] // 2))
        learning_rate = kwargs.get('learning_rate', self.config['recommended_lr'] / 2)
        output_dir = kwargs.get('output_dir', f'./models/qwen3_lora_{self.model_size.lower()}_checkpoints')

        print(f"超参数: epochs={num_epochs}, batch_size={batch_size}, lr={learning_rate}")

        train_dataset = self._create_instruction_data(train_data)
        tokenized_dataset = train_dataset.map(
            self._tokenize_function,
            batched=True,
            remove_columns=train_dataset.column_names
        )

        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=2,
            learning_rate=learning_rate,
            logging_steps=10,
            save_steps=100,
            save_total_limit=2,
            remove_unused_columns=False,
            dataloader_drop_last=False,
            # The string "none" disables logging integrations; None does not.
            report_to="none",
        )

        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        trainer = Trainer(
            model=self.lora_model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=data_collator,
            tokenizer=self.tokenizer,
        )

        print(f"开始LoRA微调...")
        trainer.train()

        self.lora_model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)

        self.model = self.lora_model
        self.is_trained = True
        print(f"Qwen3-{self.model_size}-LoRA 模型训练完成!")

    def _extract_sentiment(self, generated_text: str, instruction: str) -> int:
        """Map generated text to a label: 1 for '正面', 0 otherwise.

        The first ``len(instruction)`` characters are skipped. When both
        answers occur in the remainder, the one that appears first wins;
        when neither occurs the result is 0.
        """
        response = generated_text[len(instruction):].strip()

        positive_at = response.find("正面")
        negative_at = response.find("负面")

        if positive_at == -1:
            return 0
        if negative_at == -1:
            return 1
        return 1 if positive_at < negative_at else 0

    def predict(self, texts: List[str]) -> List[int]:
        """Predict a 0/1 label for every text, one generation per text.

        Raises:
            ValueError: If the model has not been trained or loaded.
        """
        if not self.is_trained:
            raise ValueError(f"模型 {self.model_name} 尚未训练")

        predictions = []

        self.lora_model.eval()
        with torch.no_grad():
            for text in tqdm(texts, desc=f"Qwen3-{self.model_size}预测中"):
                pred, _ = self.predict_single(text)
                predictions.append(pred)

        return predictions

    def predict_single(self, text: str) -> Tuple[int, float]:
        """Generate an answer for one text and convert it to a label.

        Returns:
            ``(label, confidence)``. The confidence is a fixed placeholder
            of 0.8 and carries no information about the actual prediction.

        Raises:
            ValueError: If the model has not been trained or loaded.
        """
        if not self.is_trained:
            raise ValueError(f"模型 {self.model_name} 尚未训练")

        instruction = self._build_instruction(text)

        inputs = self.tokenizer(instruction, return_tensors="pt")
        # Send the inputs to wherever the model's weights actually live.
        model_device = next(self.lora_model.parameters()).device
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        self.lora_model.eval()
        with torch.no_grad():
            outputs = self.lora_model.generate(
                **inputs,
                max_new_tokens=10,
                # Greedy decoding keeps the label reproducible across calls.
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens: the prompt itself contains
        # both answer words and need not survive a decode round-trip intact.
        prompt_length = inputs["input_ids"].shape[1]
        completion = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        prediction = self._extract_sentiment(instruction + completion, instruction)
        confidence = 0.8  # placeholder, not derived from the model output

        return prediction, confidence

    def save_model(self, model_path: str = None) -> None:
        """Save the LoRA adapter weights and tokenizer into a directory.

        Args:
            model_path: Target directory. Defaults to
                ``MODEL_PATHS["lora"][self.model_size]``.

        Raises:
            ValueError: If the model has not been trained or loaded.
        """
        if not self.is_trained:
            raise ValueError(f"模型 {self.model_name} 尚未训练")

        if model_path is None:
            model_path = MODEL_PATHS["lora"][self.model_size]

        os.makedirs(model_path, exist_ok=True)

        self.lora_model.save_pretrained(model_path)
        self.tokenizer.save_pretrained(model_path)

        print(f"LoRA模型已保存到: {model_path}")

    def load_model(self, model_path: str) -> None:
        """Load the base model and attach LoRA weights from ``model_path``.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        self._load_base_model()
        self.lora_model = PeftModel.from_pretrained(self.base_model, model_path)

        self.model = self.lora_model
        self.is_trained = True
        print(f"已加载Qwen3-{self.model_size}-LoRA模型: {model_path}")
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: train or evaluate a Qwen3 LoRA classifier.

    When ``--model_size`` is omitted the size is asked for interactively.
    Training uses at most ``--max_samples`` examples and evaluation uses the
    first 50 test examples. A few demo predictions follow.
    """
    parser = argparse.ArgumentParser(description='Qwen3-LoRA通用训练脚本')
    parser.add_argument('--model_size', type=str, choices=['0.6B', '4B', '8B'],
                        help='模型大小')
    parser.add_argument('--train_path', type=str, default='./dataset/train.txt',
                        help='训练数据路径')
    parser.add_argument('--test_path', type=str, default='./dataset/test.txt',
                        help='测试数据路径')
    parser.add_argument('--model_path', type=str, help='模型保存路径(可选)')
    parser.add_argument('--epochs', type=int, default=3, help='训练轮数')
    parser.add_argument('--batch_size', type=int, help='批大小(可选,使用推荐值)')
    parser.add_argument('--learning_rate', type=float, help='学习率(可选,使用推荐值)')
    parser.add_argument('--lora_r', type=int, help='LoRA秩(可选,使用推荐值)')
    parser.add_argument('--max_samples', type=int, default=1000, help='最大训练样本数')
    parser.add_argument('--eval_only', action='store_true', help='仅评估模式')

    args = parser.parse_args()

    # No size on the command line: fall back to an interactive menu.
    if not args.model_size:
        menu_lines = (
            "Qwen3-LoRA模型训练",
            "="*40,
            "可用模型大小:",
            " 1. 0.6B - 轻量级,训练快速,显存需求约4GB",
            " 2. 4B - 中等规模,性能均衡,显存需求约16GB",
            " 3. 8B - 大规模,性能最佳,显存需求约32GB",
            "\n注意: LoRA微调比Embedding方法需要更多显存",
        )
        for menu_line in menu_lines:
            print(menu_line)

        size_by_choice = {'1': '0.6B', '2': '4B', '3': '8B'}
        selected_size = None
        while selected_size is None:
            answer = input("\n请选择模型大小 (1/2/3): ").strip()
            selected_size = size_by_choice.get(answer)
            if selected_size is None:
                print("无效选择,请输入 1、2 或 3")
        args.model_size = selected_size

        print(f"已选择: Qwen3-{args.model_size} + LoRA")
        print()

    os.makedirs('./models', exist_ok=True)

    model = Qwen3LoRAUniversal(args.model_size)

    # An explicit --model_path wins over the configured default location.
    model_path = args.model_path or MODEL_PATHS["lora"][args.model_size]

    if args.eval_only:
        print(f"评估模式:加载Qwen3-{args.model_size}-LoRA模型")
        model.load_model(model_path)

        _, test_data = BaseQwenModel.load_data(args.train_path, args.test_path)
        # Generation is slow, so only a small slice is evaluated.
        model.evaluate(test_data[:50])
    else:
        train_data, test_data = BaseQwenModel.load_data(args.train_path, args.test_path)

        # LoRA training is expensive: cap the number of training examples.
        train_subset = train_data[:args.max_samples]
        print(f"使用 {len(train_subset)} 条数据进行LoRA训练")

        # Forward only the overrides the user actually supplied.
        train_kwargs = {'num_epochs': args.epochs}
        overrides = {
            'batch_size': args.batch_size,
            'learning_rate': args.learning_rate,
            'lora_r': args.lora_r,
        }
        train_kwargs.update({key: value for key, value in overrides.items() if value})

        model.train(train_subset, **train_kwargs)
        model.evaluate(test_data[:50])
        model.save_model(model_path)

    # NOTE(review): the original indentation was lost in extraction; the demo
    # is assumed to run after both the training and the eval-only branch —
    # confirm against the real file.
    print(f"\nQwen3-{args.model_size}-LoRA 示例预测:")
    demo_texts = [
        "今天天气真好,心情很棒",
        "这部电影太无聊了,浪费时间",
        "哈哈哈,太有趣了"
    ]

    for sample in demo_texts:
        label, confidence = model.predict_single(sample)
        print(f"文本: {sample}")
        print(f"预测: {'正面' if label == 1 else '负面'} (置信度: {confidence:.4f})")
        print()


if __name__ == "__main__":
    main()
|
||||
@@ -4,10 +4,95 @@
|
||||
|
||||
## 项目背景
|
||||
|
||||
本文件夹专门用于基于阿里Qwen3系列模型的微博情感分析任务。根据最新的模型评测结果,Qwen3的小参数模型(如0.6B、4B、8B、14B)在话题识别、情感分析等相对简单的自然语言处理任务上表现优异,超越了传统的BERT等基础模型。
|
||||
本文件夹专门用于基于阿里Qwen3系列模型的微博情感分析任务。根据最新的模型评测结果,Qwen3的小参数模型(0.6B、4B、8B)在话题识别、情感分析等相对简单的自然语言处理任务上表现优异,超越了传统的BERT等基础模型。
|
||||
|
||||
qwen 0.6B模型加线性分类器,做特定领域的文本分类和序列标注,优于bert,也优于235B的qwen3 few shot learning。在算力有限的情况下,性价比很高...
|
||||
|
||||
在经过了一些相关的调研之后,我觉得将Qwen3的一些小参数模型用在本系统中是一个不错的选择。
|
||||
|
||||
虽然这个参数量在LLM时代算小,但作为个人开发者计算资源有限,微调它们还是实属不易。
|
||||
|
||||
## 问题探究
|
||||
|
||||
另外我也比较好奇一个问题:例如对于Qwen3-Embedding-0.6B跟Qwen3-0.6B这两个模型,前者我接一个分类头做情感二分类,后者我进行lora微调,在同样的数据集上训练,哪个的效果更好,各有什么优势?
|
||||
|
||||
**在绝大多数情况下,使用 Qwen3-0.6B 进行 LoRA 微调的分类效果会显著优于使用 Qwen3-Embedding-0.6B 外接分类头,但在训练开销和推理速度上不如直接接分类头的方案。**
|
||||
|
||||
因此本模块对于所有参数规模的模型都提供**微调**与**嵌入再接分类头**两个版本,供大家取舍。
|
||||
|
||||
我们通过一个表格来清晰地展示两者的区别和优劣势:
|
||||
|
||||
| 特性 / 维度 | 方法 A: `Qwen3-Embedding-0.6B` + 分类头 | 方法 B: `Qwen3-0.6B` + LoRA 微调 |
|
||||
| ----------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
|
||||
| **核心思想** | **表示学习 (Representation Learning)** | **指令遵循 (Instruction Following)** |
|
||||
| **模型学习方式** | 冻结Embedding模型,只训练一个非常小的分类头(如`nn.Linear`),学习从固定文本向量到情感标签的映射。 | 冻结大部分基础模型参数,通过训练LoRA“适配器”来微调模型**内部的注意力机制和知识表达**,使其学会按指令生成特定答案。 |
|
||||
| **性能上限** | **较低**。模型的理解能力被`Qwen3-Embedding-0.6B`的通用语义表示所限制,无法学习你数据集中特有的、细微的情感模式。 | **更高**。模型在微调中调整了自身对语言的理解方式,以适应你的特定任务和数据分布,能更好地捕捉讽刺、网络用语等复杂情感。 |
|
||||
| **灵活性** | **低**。模型只能做这一件事:输出分类标签。无法扩展。 | **高**。模型学会的是一个“任务技能”。你可以轻松修改指令,让它输出“积极/消极/中性”,甚至“为什么这是积极的?”。 |
|
||||
| **训练资源开销** | **极低**。只需训练一个几KB到几MB的分类头,普通CPU都能完成。显存占用非常小。 | **较高**。虽然LoRA效率很高,但仍需在GPU上进行,需要加载整个0.6B模型和LoRA参数到显存中进行反向传播。 |
|
||||
| **推理速度/成本** | **极快、极低**。一次前向传播即可获得Embedding向量,分类头计算可忽略不计。非常适合大规模、低延迟的生产环境。 | **较慢、较高**。需要进行自回归生成(一个词一个词地蹦),即使答案很短(如“积极”),也比一次性前向传播慢几个数量级。 |
|
||||
| **实现复杂度** | **简单**。遵循BERT时代的技术范式,流程成熟,代码直观。 | **中等**。需要构建指令模板、配置LoRA参数、使用SFTTrainer等,比前者稍复杂,但已有成熟框架支持。 |
|
||||
|
||||
## 使用说明
|
||||
|
||||
### 环境配置
|
||||
```bash
|
||||
# 安装依赖
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 激活pytorch环境
|
||||
conda activate 你的环境名
|
||||
```
|
||||
|
||||
### 训练模型
|
||||
|
||||
**Embedding + 分类头方法:**
|
||||
```bash
|
||||
python qwen3_embedding_universal.py
|
||||
# 程序会询问选择模型大小(0.6B/4B/8B)
|
||||
```
|
||||
|
||||
**LoRA微调方法:**
|
||||
```bash
|
||||
python qwen3_lora_universal.py
|
||||
# 程序会询问选择模型大小(0.6B/4B/8B)
|
||||
```
|
||||
|
||||
**命令行参数:**
|
||||
```bash
|
||||
# 直接指定模型
|
||||
python qwen3_embedding_universal.py --model_size 0.6B
|
||||
python qwen3_lora_universal.py --model_size 4B
|
||||
|
||||
# 自定义参数
|
||||
python qwen3_embedding_universal.py --model_size 8B --epochs 10 --batch_size 16
|
||||
```
|
||||
|
||||
### 预测使用
|
||||
|
||||
**交互式预测:**
|
||||
```bash
|
||||
python predict_universal.py
|
||||
# 程序会让你选择具体的模型和方法
|
||||
```
|
||||
|
||||
**命令行预测:**
|
||||
```bash
|
||||
# 指定模型预测
|
||||
python predict_universal.py --model_type embedding --model_size 0.6B --text "今天天气真好"
|
||||
|
||||
# 加载所有模型
|
||||
python predict_universal.py --load_all --text "这个电影太棒了"
|
||||
```
|
||||
|
||||
### 注意事项
|
||||
|
||||
1. **显存要求**:
|
||||
- 0.6B: 最低2GB显存
|
||||
- 4B: 最低8GB显存
|
||||
- 8B: 最低16GB显存
|
||||
|
||||
2. **数据格式**:每行格式为`文本内容\t标签`,标签为0(负面)或1(正面)
|
||||
|
||||
3. **模型选择**:初次使用建议从0.6B模型开始测试
|
||||
|
||||
4. **训练时间**:LoRA微调比Embedding方法耗时更长,建议使用GPU加速
|
||||
@@ -0,0 +1,11 @@
|
||||
torch>=2.0.0
|
||||
transformers>=4.51.0
|
||||
peft>=0.7.0
|
||||
datasets>=2.14.0
|
||||
accelerate>=0.25.0
|
||||
scikit-learn>=1.3.0
|
||||
pandas>=1.5.0
|
||||
numpy>=1.24.0
|
||||
tqdm>=4.65.0
|
||||
sentence-transformers>=2.7.0
|
||||
bitsandbytes>=0.41.0
|
||||
Reference in New Issue
Block a user