"""Unit tests for the init and fit functions of the TopicGPT class."""

import os
import sys
import inspect

# The statements below interleave sys.path changes with imports on purpose:
# "<parent>/src" must be on the path before `topicgpt` is imported, and the
# parent directory itself before the `src.topicgpt...` imports. Do not reorder.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, f"{parentdir}/src")

from topicgpt.TopicGPT import TopicGPT

sys.path.insert(0, parentdir)

import openai
import pickle
import unittest

from src.topicgpt.TopicRepresentation import Topic
from src.topicgpt.Clustering import Clustering_and_DimRed
from src.topicgpt.TopwordEnhancement import TopwordEnhancement
from src.topicgpt.TopicPrompting import TopicPrompting


class TestTopicGPT_init_and_fit(unittest.TestCase):
    """Test the init and fit functions of the TopicGPT class."""

    @classmethod
    def setUpClass(cls, sample_size = 0.1):
        """Load the precomputed test data and keep only a leading sample of it.

        Sets the class attributes ``api_key_openai``, ``corpus``,
        ``doc_embeddings`` and ``embeddings_vocab`` that the tests read.

        Args:
            sample_size: Fraction of the corpus (and of the matching document
                embeddings) to keep, taken from the start of the data.

        Raises:
            OSError: If one of the pickle files cannot be opened. The paths
                are relative, so they are resolved against the current
                working directory.
        """
        print("Setting up class...")
        cls.api_key_openai = os.environ.get('api_key')
        # TODO: The 'openai.organization' option isn't read in the client API.
        # You will need to pass it when you instantiate the client, e.g.
        # 'OpenAI(organization=os.environ.get('OPENAI_ORG'))'
        # openai.organization = os.environ.get('OPENAI_ORG')

        # NOTE: pickle.load executes arbitrary code from the file; these
        # fixtures must only ever come from a trusted source.
        with open("Data/Emebeddings/embeddings_20ng_raw.pkl", "rb") as f:
            data_raw = pickle.load(f)

        corpus = data_raw["corpus"]
        doc_embeddings = data_raw["embeddings"]

        # Corpus and embeddings are truncated to the same length so that they
        # stay aligned document by document.
        n_docs = int(len(corpus) * sample_size)
        cls.corpus = corpus[:n_docs]
        cls.doc_embeddings = doc_embeddings[:n_docs]
        print(f"Using {n_docs} out of {len(corpus)} documents")

        with open("Data/Emebeddings/embeddings_20ng_vocab.pkl", "rb") as f:
            cls.embeddings_vocab = pickle.load(f)

    def test_init(self):
        """Test the init function of the TopicGPT class."""
        print("Testing init...")

        topicgpt = TopicGPT(api_key = self.api_key_openai)
        self.assertIsInstance(topicgpt, TopicGPT)

        topicgpt = TopicGPT(api_key = self.api_key_openai, n_topics = 20)
        self.assertIsInstance(topicgpt, TopicGPT)

        topicgpt = TopicGPT(
            api_key = self.api_key_openai,
            n_topics = 20,
            corpus_instruction = "This is a corpus instruction",
            document_embeddings = self.doc_embeddings,
            vocab_embeddings = self.embeddings_vocab,
        )
        self.assertIsInstance(topicgpt, TopicGPT)

        # Check that the constructor's assertions are triggered by invalid
        # arguments: no API key, zero topics, zero token budget.
        with self.assertRaises(AssertionError):
            TopicGPT(
                api_key = None,
                n_topics = 32,
                openai_prompting_model = "gpt-4",
                max_number_of_tokens = 8000,
                corpus_instruction = "This is a corpus instruction",
            )

        with self.assertRaises(AssertionError):
            TopicGPT(
                api_key = self.api_key_openai,
                n_topics = 0,
                max_number_of_tokens = 8000,
                corpus_instruction = "This is a corpus instruction",
            )

        with self.assertRaises(AssertionError):
            TopicGPT(
                api_key = self.api_key_openai,
                n_topics = 20,
                max_number_of_tokens = 0,
                corpus_instruction = "This is a corpus instruction",
            )

    def _fit_and_check(self, topicgpt):
        """Fit ``topicgpt`` on the sampled corpus and check the fitted state.

        Args:
            topicgpt: A TopicGPT instance that has not been fitted yet.
        """
        topicgpt.fit(self.corpus)

        self.assertTrue(hasattr(topicgpt, "vocab"), "fit() did not set 'vocab'")
        self.assertTrue(hasattr(topicgpt, "topic_lis"), "fit() did not set 'topic_lis'")

        self.assertIsInstance(topicgpt.vocab, list)
        self.assertIsInstance(topicgpt.vocab[0], str)
        self.assertIsInstance(topicgpt.topic_lis, list)

        # Deliberately non-fatal. The original code caught the AssertionError
        # of this check and only printed it, so the check never failed the test.
        # NOTE(review): `Topic` is imported via `src.topicgpt...` while
        # `TopicGPT` is imported via `topicgpt...`; that may yield two distinct
        # class objects, which would explain why this was made non-fatal.
        # Confirm, then promote to `assertIsInstance`.
        first_topic = topicgpt.topic_lis[0]
        if not isinstance(first_topic, Topic):
            print(f"topic_lis[0] is not an instance of {Topic!r}")
            print(type(first_topic))
            print(first_topic)

        # With n_topics=None the number of topics is not fixed in advance, so
        # the count is only checked when it was requested explicitly.
        if topicgpt.n_topics is not None:
            self.assertEqual(len(topicgpt.topic_lis), topicgpt.n_topics)

        # The prompting component must hold the same fitted state as the model.
        # Plain `==` inside assertTrue is kept (instead of assertEqual) to avoid
        # building a diff of these potentially very large objects on failure.
        self.assertTrue(
            topicgpt.topic_lis == topicgpt.topic_prompting.topic_lis,
            "topic_lis differs between TopicGPT and its topic_prompting",
        )
        self.assertTrue(
            topicgpt.vocab == topicgpt.topic_prompting.vocab,
            "vocab differs between TopicGPT and its topic_prompting",
        )
        self.assertTrue(
            topicgpt.vocab_embeddings == topicgpt.topic_prompting.vocab_embeddings,
            "vocab_embeddings differs between TopicGPT and its topic_prompting",
        )

    def test_fit(self):
        """Test the fit function of the TopicGPT class."""
        print("Testing fit...")

        # Fixed number of topics.
        topicgpt1 = TopicGPT(
            api_key = self.api_key_openai,
            n_topics = 20,
            document_embeddings = self.doc_embeddings,
            vocab_embeddings = self.embeddings_vocab,
        )

        # Number of topics left unspecified.
        topicgpt2 = TopicGPT(
            api_key = self.api_key_openai,
            n_topics = None,
            document_embeddings = self.doc_embeddings,
            vocab_embeddings = self.embeddings_vocab,
        )

        # Single topic with explicit top-word settings.
        topicgpt3 = TopicGPT(
            api_key = self.api_key_openai,
            n_topics = 1,
            document_embeddings = self.doc_embeddings,
            vocab_embeddings = self.embeddings_vocab,
            n_topwords = 10,
            n_topwords_description = 10,
            topword_extraction_methods = ["cosine_similarity"],
        )

        # Externally constructed clusterer and prompting components.
        clusterer4 = Clustering_and_DimRed(
            n_dims_umap = 10,
            n_neighbors_umap = 20,
            min_cluster_size_hdbscan = 10,
            number_clusters_hdbscan = 10,  # use only 10 clusters
        )
        topword_enhancement4 = TopwordEnhancement(api_key = self.api_key_openai)
        topic_prompting4 = TopicPrompting(
            api_key = self.api_key_openai,
            enhancer = topword_enhancement4,
            topic_lis = None,
        )
        topicgpt4 = TopicGPT(
            api_key = self.api_key_openai,
            n_topics = None,
            document_embeddings = self.doc_embeddings,
            vocab_embeddings = self.embeddings_vocab,
            topic_prompting = topic_prompting4,
            clusterer = clusterer4,
            topword_extraction_methods = ["tfidf"],
        )

        # Plain loop (no subTest): the first failing configuration stops the
        # test, so the remaining configurations are not fitted.
        for topic_gpt in (topicgpt1, topicgpt2, topicgpt3, topicgpt4):
            self._fit_and_check(topic_gpt)


if __name__ == "__main__":
    unittest.main()