Files
bettafish-company/LLMTopicDetection_TopicGPT/test/TestTopicGPT_init_and_fit.py
T

178 lines
6.7 KiB
Python

"""
This module contains the unit tests for the init and fit functions of the TopicGPT class
"""
import os
import sys
import inspect
# Absolute directory containing this test file, and that directory's parent.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
# Put "<parentdir>/src" at the front of sys.path before importing the
# `topicgpt` package on the next line.
sys.path.insert(0, f"{parentdir}/src")
from topicgpt.TopicGPT import TopicGPT
# Put "<parentdir>" itself at the front of sys.path before the
# `src.topicgpt.*` imports further down.
# NOTE(review): the package is now importable under two names (`topicgpt` and
# `src.topicgpt`). Python caches modules by import name, so classes imported
# through one name are presumably distinct objects from the "same" classes
# imported through the other -- confirm before relying on type()/isinstance()
# comparisons across the two import paths.
sys.path.insert(0, parentdir)
# NOTE(review): in the code visible here `openai` is only referenced from a
# commented-out line -- confirm whether the import is still needed.
import openai
import pickle
import unittest
from src.topicgpt.TopicRepresentation import Topic
from src.topicgpt.Clustering import Clustering_and_DimRed
from src.topicgpt.TopwordEnhancement import TopwordEnhancement
from src.topicgpt.TopicPrompting import TopicPrompting
class TestTopicGPT_init_and_fit(unittest.TestCase):
    """
    Test the init and fit functions of the TopicGPT class.

    The fixtures are loaded once per class from two pickle files whose paths
    are relative to the current working directory. The OpenAI API key is read
    from the ``api_key`` environment variable.
    """

    @classmethod
    def setUpClass(cls, sample_size=0.1):
        """
        Load the necessary data and only keep a sample of it.

        Args:
            sample_size: fraction of the corpus to keep. The first
                ``int(len(corpus) * sample_size)`` documents and the matching
                leading slice of the document embeddings are stored on the
                class.
        """
        print("Setting up class...")
        cls.api_key_openai = os.environ.get('api_key')
        # TODO: The 'openai.organization' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(organization=os.environ.get('OPENAI_ORG'))'
        # openai.organization = os.environ.get('OPENAI_ORG')

        # NOTE: pickle.load can execute arbitrary code contained in the file;
        # these paths must only ever point at trusted fixture files.
        with open("Data/Emebeddings/embeddings_20ng_raw.pkl", "rb") as f:
            data_raw = pickle.load(f)

        corpus = data_raw["corpus"]
        doc_embeddings = data_raw["embeddings"]

        n_docs = int(len(corpus) * sample_size)
        cls.corpus = corpus[:n_docs]
        cls.doc_embeddings = doc_embeddings[:n_docs]
        print("Using {} out of {} documents".format(n_docs, len(corpus)))

        with open("Data/Emebeddings/embeddings_20ng_vocab.pkl", "rb") as f:
            cls.embeddings_vocab = pickle.load(f)

    def test_init(self):
        """
        Test the init function of the TopicGPT class.

        Valid argument combinations must produce a TopicGPT instance; a
        missing API key, ``n_topics=0`` and ``max_number_of_tokens=0`` must
        each raise an AssertionError.
        """
        print("Testing init...")

        topicgpt = TopicGPT(api_key=self.api_key_openai)
        self.assertIsInstance(topicgpt, TopicGPT)

        topicgpt = TopicGPT(api_key=self.api_key_openai,
                            n_topics=20)
        self.assertIsInstance(topicgpt, TopicGPT)

        topicgpt = TopicGPT(api_key=self.api_key_openai,
                            n_topics=20,
                            corpus_instruction="This is a corpus instruction",
                            document_embeddings=self.doc_embeddings,
                            vocab_embeddings=self.embeddings_vocab)
        self.assertIsInstance(topicgpt, TopicGPT)

        # check if assertions are triggered
        with self.assertRaises(AssertionError):
            TopicGPT(api_key=None,
                     n_topics=32,
                     openai_prompting_model="gpt-4",
                     max_number_of_tokens=8000,
                     corpus_instruction="This is a corpus instruction")

        with self.assertRaises(AssertionError):
            TopicGPT(api_key=self.api_key_openai,
                     n_topics=0,
                     max_number_of_tokens=8000,
                     corpus_instruction="This is a corpus instruction")

        with self.assertRaises(AssertionError):
            TopicGPT(api_key=self.api_key_openai,
                     n_topics=20,
                     max_number_of_tokens=0,
                     corpus_instruction="This is a corpus instruction")

    def _check_fitted_instance(self, topicgpt):
        """
        Fit ``topicgpt`` on the sampled corpus and check the resulting state.

        Verifies that ``vocab`` and ``topic_lis`` exist and are lists, that
        the first vocabulary entry is a string, that the number of topics
        matches ``n_topics`` when one was requested, and that the state held
        by ``topicgpt.topic_prompting`` mirrors the state on ``topicgpt``.
        """
        topicgpt.fit(self.corpus)

        self.assertTrue(hasattr(topicgpt, "vocab"),
                        "fit() did not set the 'vocab' attribute")
        self.assertTrue(hasattr(topicgpt, "topic_lis"),
                        "fit() did not set the 'topic_lis' attribute")

        self.assertIsInstance(topicgpt.vocab, list)
        self.assertIsInstance(topicgpt.vocab[0], str)
        self.assertIsInstance(topicgpt.topic_lis, list)

        # Deliberately non-fatal: a mismatch here is only reported, never
        # turned into a test failure.
        # NOTE(review): confirm whether this check is meant to become a hard
        # assertion.
        first_topic = topicgpt.topic_lis[0]
        if not isinstance(first_topic, Topic):
            print("First entry of topic_lis is not a Topic instance")
            print(type(first_topic))
            print(first_topic)

        if topicgpt.n_topics is not None:
            self.assertEqual(len(topicgpt.topic_lis), topicgpt.n_topics)

        # Plain `==` with an explicit message is used instead of assertEqual
        # so that a mismatch does not try to render a diff of these
        # potentially very large containers.
        self.assertTrue(topicgpt.topic_lis == topicgpt.topic_prompting.topic_lis,
                        "topic_lis differs between TopicGPT and its topic_prompting")
        self.assertTrue(topicgpt.vocab == topicgpt.topic_prompting.vocab,
                        "vocab differs between TopicGPT and its topic_prompting")
        self.assertTrue(topicgpt.vocab_embeddings == topicgpt.topic_prompting.vocab_embeddings,
                        "vocab_embeddings differs between TopicGPT and its topic_prompting")

    def test_fit(self):
        """
        Test the fit function of the TopicGPT class.

        Four differently configured TopicGPT instances are fitted on the
        sampled corpus; each one runs in its own subTest so that a failure
        names the configuration that caused it.
        """
        print("Testing fit...")

        topicgpt1 = TopicGPT(api_key=self.api_key_openai,
                             n_topics=20,
                             document_embeddings=self.doc_embeddings,
                             vocab_embeddings=self.embeddings_vocab)

        topicgpt2 = TopicGPT(api_key=self.api_key_openai,
                             n_topics=None,
                             document_embeddings=self.doc_embeddings,
                             vocab_embeddings=self.embeddings_vocab)

        topicgpt3 = TopicGPT(api_key=self.api_key_openai,
                             n_topics=1,
                             document_embeddings=self.doc_embeddings,
                             vocab_embeddings=self.embeddings_vocab,
                             n_topwords=10,
                             n_topwords_description=10,
                             topword_extraction_methods=["cosine_similarity"])

        clusterer4 = Clustering_and_DimRed(
            n_dims_umap=10,
            n_neighbors_umap=20,
            min_cluster_size_hdbscan=10,
            number_clusters_hdbscan=10  # use only 10 clusters
        )
        topword_enhancement4 = TopwordEnhancement(api_key=self.api_key_openai)
        topic_prompting4 = TopicPrompting(
            api_key=self.api_key_openai,
            enhancer=topword_enhancement4,
            topic_lis=None
        )
        topicgpt4 = TopicGPT(api_key=self.api_key_openai,
                             n_topics=None,
                             document_embeddings=self.doc_embeddings,
                             vocab_embeddings=self.embeddings_vocab,
                             topic_prompting=topic_prompting4,
                             clusterer=clusterer4,
                             topword_extraction_methods=["tfidf"])

        configurations = [
            ("n_topics_20", topicgpt1),
            ("n_topics_none", topicgpt2),
            ("n_topics_1_cosine_similarity", topicgpt3),
            ("custom_clusterer_and_prompting_tfidf", topicgpt4),
        ]
        for label, topic_gpt in configurations:
            with self.subTest(configuration=label):
                self._check_fitted_instance(topic_gpt)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()