"""Quick smoke tests for TopicGPT that do not require any saved data.

Run as a script with ``python <this file> --api-key <OPENAI_KEY>``.
When the module is collected by a test runner instead, the key is read from
the ``OPENAI_API_KEY`` environment variable; if no key is available the tests
are skipped rather than failing with a ``NameError``.
"""
import os
import sys
import unittest

from sklearn.datasets import fetch_20newsgroups

from topicgpt.TopicGPT import TopicGPT
from topicgpt.TopicRepresentation import Topic

# OpenAI API key shared by all tests. It is defined at import time so the
# name always exists; ``main()`` overrides it when ``--api-key`` is given.
# An empty environment variable is treated the same as an unset one.
api_key = os.environ.get("OPENAI_API_KEY") or None


class QuickTestTopicGPT_init_and_fit(unittest.TestCase):
    """Run some basic tests on TopicGPT that do not require any saved data."""

    @classmethod
    def setUpClass(cls, sample_size: int = 500):
        """Download the 20 Newsgroups dataset and keep only a sample of it.

        params:
            sample_size: the number of documents to use for the test
        """
        data = fetch_20newsgroups(
            subset='all',
            remove=('headers', 'footers', 'quotes'),
        )
        # Drop empty documents first so the sample holds ``sample_size``
        # non-empty documents (as long as the dataset has that many).
        non_empty_docs = [doc for doc in data['data'] if doc != ""]
        cls.corpus = non_empty_docs[:sample_size]

    def setUp(self):
        """Expose the module-level API key to the test, or skip without one."""
        if api_key is None:
            self.skipTest(
                "No OpenAI API key available: pass --api-key or set "
                "OPENAI_API_KEY"
            )
        self.api_key_openai = api_key

    def test_init(self):
        """Test the init function of the TopicGPT class."""
        print("Testing init...")

        topicgpt = TopicGPT(api_key=self.api_key_openai)
        self.assertIsInstance(topicgpt, TopicGPT)

        topicgpt = TopicGPT(api_key=self.api_key_openai, n_topics=20)
        self.assertIsInstance(topicgpt, TopicGPT)

        topicgpt = TopicGPT(
            api_key=self.api_key_openai,
            n_topics=20,
            corpus_instruction="This is a corpus instruction",
        )
        self.assertIsInstance(topicgpt, TopicGPT)

        # Check that invalid arguments trigger assertions: a missing API key,
        # zero topics and a zero token budget are each expected to raise.
        with self.assertRaises(AssertionError):
            TopicGPT(
                api_key=None,
                n_topics=32,
                openai_prompting_model="gpt-4",
                max_number_of_tokens=8000,
                corpus_instruction="This is a corpus instruction",
            )

        with self.assertRaises(AssertionError):
            TopicGPT(
                api_key=self.api_key_openai,
                n_topics=0,
                max_number_of_tokens=8000,
                corpus_instruction="This is a corpus instruction",
            )

        with self.assertRaises(AssertionError):
            TopicGPT(
                api_key=self.api_key_openai,
                n_topics=20,
                max_number_of_tokens=0,
                corpus_instruction="This is a corpus instruction",
            )

    def _check_fitted_instance(self, topicgpt):
        """Fit ``topicgpt`` on the sampled corpus and validate its state.

        Checks that fitting populates ``vocab`` (a list of strings) and
        ``topic_lis`` (a list of ``Topic`` objects), that the number of topics
        matches ``n_topics`` when one was requested, and that the
        ``topic_prompting`` helper holds the same topics, vocabulary and
        vocabulary embeddings.
        """
        topicgpt.fit(self.corpus)

        self.assertTrue(hasattr(topicgpt, "vocab"))
        self.assertTrue(hasattr(topicgpt, "topic_lis"))

        self.assertIsInstance(topicgpt.vocab, list)
        self.assertIsInstance(topicgpt.vocab[0], str)

        self.assertIsInstance(topicgpt.topic_lis, list)
        self.assertIsInstance(topicgpt.topic_lis[0], Topic)

        if topicgpt.n_topics is not None:
            self.assertEqual(len(topicgpt.topic_lis), topicgpt.n_topics)

        self.assertEqual(topicgpt.topic_lis, topicgpt.topic_prompting.topic_lis)
        self.assertEqual(topicgpt.vocab, topicgpt.topic_prompting.vocab)
        # Kept as a plain ``==`` check: the container type of the embeddings
        # is not visible here, so no type-specific assertion is assumed.
        self.assertTrue(
            topicgpt.vocab_embeddings
            == topicgpt.topic_prompting.vocab_embeddings
        )

    def test_fit(self):
        """Test the fit function of the TopicGPT class."""
        print("Testing fit...")

        topicgpt1 = TopicGPT(api_key=self.api_key_openai, n_topics=1)

        for topic_gpt in [topicgpt1]:
            self._check_fitted_instance(topic_gpt)


def main():
    """Parse ``--api-key`` from the command line and run the test suite.

    The flag and its value are removed from ``sys.argv`` before handing over
    to ``unittest.main()``, which would otherwise reject the unknown option.
    Exits with status 1 when the flag has no value or no key is available.
    """
    global api_key

    flag = "--api-key"
    if flag in sys.argv:
        flag_index = sys.argv.index(flag)
        if flag_index + 1 >= len(sys.argv):
            # The flag was given as the last argument, without a value.
            print("API key must be provided with --api-key")
            sys.exit(1)
        api_key = sys.argv.pop(flag_index + 1)
        sys.argv.pop(flag_index)

    if api_key is None:
        print("API key must be provided with --api-key")
        sys.exit(1)

    unittest.main()


if __name__ == "__main__":
    main()