The LLM-based topic recognition model is complete and adapted to quickly updating Weibo topics.

This commit is contained in:
戒酒的李白
2025-08-07 11:14:38 +08:00
parent 1e780876c9
commit d88d5edd99
32 changed files with 8352 additions and 1 deletions
@@ -0,0 +1,120 @@
from topicgpt.TopicRepresentation import Topic
import unittest
from sklearn.datasets import fetch_20newsgroups
from topicgpt.TopicGPT import TopicGPT
class QuickTestTopicGPT_init_and_fit(unittest.TestCase):
"""
Run some basic tests on TopicGPT that do not require any saved data
"""
@classmethod
def setUpClass(cls, sample_size:int = 500):
"""
download the necessary data and only keep a sample of it
params:
api_key: the openai api key
sample_size: the number of documents to use for the test
"""
data = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes')) #download the 20 Newsgroups dataset
corpus = data['data']# just select the first 1000 documents for this example
corpus = [doc for doc in corpus if doc != ""]
corpus = corpus[:sample_size]
cls.corpus = corpus
def setUp(self):
self.api_key_openai = api_key
def test_init(self):
"""
test the init function of the TopicGPT class
"""
print("Testing init...")
topicgpt = TopicGPT(api_key = self.api_key_openai)
self.assertTrue(isinstance(topicgpt, TopicGPT))
topicgpt = TopicGPT(api_key = self.api_key_openai,
n_topics= 20)
self.assertTrue(isinstance(topicgpt, TopicGPT))
topicgpt = TopicGPT(api_key = self.api_key_openai,
n_topics= 20,
corpus_instruction="This is a corpus instruction")
self.assertTrue(isinstance(topicgpt, TopicGPT))
# check if assertions are triggered
with self.assertRaises(AssertionError):
topicgpt = TopicGPT(api_key = None,
n_topics= 32,
openai_prompting_model="gpt-4",
max_number_of_tokens=8000,
corpus_instruction="This is a corpus instruction")
with self.assertRaises(AssertionError):
topicgpt = TopicGPT(api_key = self.api_key_openai,
n_topics= 0,
max_number_of_tokens=8000,
corpus_instruction="This is a corpus instruction")
with self.assertRaises(AssertionError):
topicgpt = TopicGPT(api_key = self.api_key_openai,
n_topics= 20,
max_number_of_tokens=0,
corpus_instruction="This is a corpus instruction")
def test_fit(self):
"""
test the fit function of the TopicGPT class
"""
print("Testing fit...")
def instance_test(topicgpt):
topicgpt.fit(self.corpus)
self.assertTrue(hasattr(topicgpt, "vocab"))
self.assertTrue(hasattr(topicgpt, "topic_lis"))
self.assertTrue(isinstance(topicgpt.vocab, list))
self.assertTrue(isinstance(topicgpt.vocab[0], str))
self.assertTrue(isinstance(topicgpt.topic_lis, list))
self.assertTrue(type(topicgpt.topic_lis[0]) == Topic)
if topicgpt.n_topics is not None:
self.assertTrue(len(topicgpt.topic_lis) == topicgpt.n_topics)
self.assertTrue(topicgpt.topic_lis == topicgpt.topic_prompting.topic_lis)
self.assertTrue(topicgpt.vocab == topicgpt.topic_prompting.vocab)
self.assertTrue(topicgpt.vocab_embeddings == topicgpt.topic_prompting.vocab_embeddings)
topicgpt1 = TopicGPT(api_key = self.api_key_openai, n_topics = 1)
topic_gpt_list = [topicgpt1]
for topic_gpt in topic_gpt_list:
instance_test(topic_gpt)
import sys
if __name__ == "__main__":
for i, arg in enumerate(sys.argv):
if arg == "--api-key":
api_key = sys.argv.pop(i + 1)
sys.argv.pop(i)
break
if api_key is None:
print("API key must be provided with --api-key")
sys.exit(1)
unittest.main()