178 lines
6.7 KiB
Python
178 lines
6.7 KiB
Python
"""
|
|
This module tests the init and fit functions of the TopicGPT class.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import inspect
|
|
# Absolute directory that contains this test file, and its parent directory.
currentdir = os.path.dirname(
    os.path.abspath(inspect.getfile(inspect.currentframe()))
)
parentdir = os.path.dirname(currentdir)

# Put "<parentdir>/src" first on the import path so that the
# ``topicgpt`` package can be imported by its top-level name.
sys.path.insert(0, f"{parentdir}/src")
|
|
from topicgpt.TopicGPT import TopicGPT
|
|
|
|
# Also put the parent directory itself first on the import path so that the
# ``src.topicgpt...`` imports further down can be resolved.
sys.path.insert(0, parentdir)
|
|
|
|
import openai
|
|
import pickle
|
|
|
|
import unittest
|
|
|
|
from src.topicgpt.TopicRepresentation import Topic
|
|
|
|
from src.topicgpt.Clustering import Clustering_and_DimRed
|
|
from src.topicgpt.TopwordEnhancement import TopwordEnhancement
|
|
from src.topicgpt.TopicPrompting import TopicPrompting
|
|
|
|
class TestTopicGPT_init_and_fit(unittest.TestCase):
    """
    Test the init and fit functions of the TopicGPT class.

    The fixtures (corpus, document embeddings and vocabulary embeddings) are
    loaded once per class from pickle files under ``Data/Emebeddings``. The
    paths are relative, so the tests must be started from a working directory
    that contains that folder.
    """

    @classmethod
    def setUpClass(cls, sample_size=0.1):
        """
        Load the necessary data and only keep a sample of it.

        Args:
            sample_size: Fraction of the corpus to keep. The first
                ``int(len(corpus) * sample_size)`` documents and the matching
                slice of the document embeddings are retained.

        Sets the class attributes ``api_key_openai``, ``corpus``,
        ``doc_embeddings`` and ``embeddings_vocab``.
        """
        print("Setting up class...")
        # The OpenAI key is read from the ``api_key`` environment variable;
        # it is ``None`` when the variable is not set.
        cls.api_key_openai = os.environ.get('api_key')
        # TODO: The 'openai.organization' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(organization=os.environ.get('OPENAI_ORG'))'
        # openai.organization = os.environ.get('OPENAI_ORG')

        with open("Data/Emebeddings/embeddings_20ng_raw.pkl", "rb") as f:
            data_raw = pickle.load(f)

        corpus = data_raw["corpus"]
        doc_embeddings = data_raw["embeddings"]

        # Keep only the leading fraction of the documents and their embeddings.
        n_docs = int(len(corpus) * sample_size)
        cls.corpus = corpus[:n_docs]
        cls.doc_embeddings = doc_embeddings[:n_docs]

        print(f"Using {n_docs} out of {len(corpus)} documents")

        with open("Data/Emebeddings/embeddings_20ng_vocab.pkl", "rb") as f:
            cls.embeddings_vocab = pickle.load(f)

    def test_init(self):
        """
        Test the init function of the TopicGPT class.

        Valid argument combinations must yield a TopicGPT instance; the three
        invalid combinations below must raise an AssertionError.
        """
        print("Testing init...")
        topicgpt = TopicGPT(api_key=self.api_key_openai)
        self.assertIsInstance(topicgpt, TopicGPT)

        topicgpt = TopicGPT(api_key=self.api_key_openai,
                            n_topics=20)
        self.assertIsInstance(topicgpt, TopicGPT)

        topicgpt = TopicGPT(api_key=self.api_key_openai,
                            n_topics=20,
                            corpus_instruction="This is a corpus instruction",
                            document_embeddings=self.doc_embeddings,
                            vocab_embeddings=self.embeddings_vocab)
        self.assertIsInstance(topicgpt, TopicGPT)

        # check if assertions are triggered

        # Missing API key.
        with self.assertRaises(AssertionError):
            TopicGPT(api_key=None,
                     n_topics=32,
                     openai_prompting_model="gpt-4",
                     max_number_of_tokens=8000,
                     corpus_instruction="This is a corpus instruction")

        # Zero topics requested.
        with self.assertRaises(AssertionError):
            TopicGPT(api_key=self.api_key_openai,
                     n_topics=0,
                     max_number_of_tokens=8000,
                     corpus_instruction="This is a corpus instruction")

        # Zero-token budget.
        with self.assertRaises(AssertionError):
            TopicGPT(api_key=self.api_key_openai,
                     n_topics=20,
                     max_number_of_tokens=0,
                     corpus_instruction="This is a corpus instruction")

    def _assert_fitted(self, topicgpt):
        """
        Fit ``topicgpt`` on the sampled corpus and check the resulting state.

        Args:
            topicgpt: A freshly constructed TopicGPT instance.
        """
        topicgpt.fit(self.corpus)

        self.assertTrue(hasattr(topicgpt, "vocab"))
        self.assertTrue(hasattr(topicgpt, "topic_lis"))

        self.assertIsInstance(topicgpt.vocab, list)
        self.assertIsInstance(topicgpt.vocab[0], str)

        self.assertIsInstance(topicgpt.topic_lis, list)
        first_topic = topicgpt.topic_lis[0]
        # Soft check: a type mismatch is only reported, it does not fail the
        # test.
        # NOTE(review): Topic is imported through ``src.topicgpt`` while
        # TopicGPT is imported through ``topicgpt``; if the library builds its
        # topics via the latter path the two classes may be distinct objects.
        # Confirm before turning this into a hard assertion.
        try:
            self.assertIsInstance(first_topic, Topic)
        except AssertionError as e:
            print(e)
            print(type(first_topic))
            print(first_topic)

        if topicgpt.n_topics is not None:
            self.assertEqual(len(topicgpt.topic_lis), topicgpt.n_topics)

        # The prompting component must see the same state as the model itself.
        self.assertEqual(topicgpt.topic_lis, topicgpt.topic_prompting.topic_lis)
        self.assertEqual(topicgpt.vocab, topicgpt.topic_prompting.vocab)
        # Kept as a plain truth test: the container type of the vocabulary
        # embeddings is not visible here, so no type-specific diffing is used.
        self.assertTrue(
            topicgpt.vocab_embeddings == topicgpt.topic_prompting.vocab_embeddings,
            "vocab_embeddings differ between TopicGPT and its topic_prompting",
        )

    def test_fit(self):
        """
        Test the fit function of the TopicGPT class.

        Four differently configured instances are fitted one after another;
        the first failing configuration aborts the test.
        """
        print("Testing fit...")

        # Fixed number of topics.
        topicgpt1 = TopicGPT(api_key=self.api_key_openai,
                             n_topics=20,
                             document_embeddings=self.doc_embeddings,
                             vocab_embeddings=self.embeddings_vocab)

        # No topic count given.
        topicgpt2 = TopicGPT(api_key=self.api_key_openai,
                             n_topics=None,
                             document_embeddings=self.doc_embeddings,
                             vocab_embeddings=self.embeddings_vocab)

        # Single topic with explicit top-word settings.
        topicgpt3 = TopicGPT(api_key=self.api_key_openai,
                             n_topics=1,
                             document_embeddings=self.doc_embeddings,
                             vocab_embeddings=self.embeddings_vocab,
                             n_topwords=10,
                             n_topwords_description=10,
                             topword_extraction_methods=["cosine_similarity"])

        # Externally constructed clusterer and prompting components.
        clusterer4 = Clustering_and_DimRed(
            n_dims_umap=10,
            n_neighbors_umap=20,
            min_cluster_size_hdbscan=10,
            number_clusters_hdbscan=10  # use only 10 clusters
        )

        topword_enhancement4 = TopwordEnhancement(api_key=self.api_key_openai)
        topic_prompting4 = TopicPrompting(
            api_key=self.api_key_openai,
            enhancer=topword_enhancement4,
            topic_lis=None
        )

        topicgpt4 = TopicGPT(api_key=self.api_key_openai,
                             n_topics=None,
                             document_embeddings=self.doc_embeddings,
                             vocab_embeddings=self.embeddings_vocab,
                             topic_prompting=topic_prompting4,
                             clusterer=clusterer4,
                             topword_extraction_methods=["tfidf"])

        topic_gpt_list = [topicgpt1, topicgpt2, topicgpt3, topicgpt4]

        for topic_gpt in topic_gpt_list:
            self._assert_fitted(topic_gpt)
|
|
|
|
|
|
|
|
|
|
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()