Add BERTopic.
This commit is contained in:
@@ -0,0 +1,155 @@
|
||||
import copy
|
||||
import pytest
|
||||
from bertopic import BERTopic
|
||||
import importlib.util
|
||||
|
||||
|
||||
def cuml_available():
|
||||
try:
|
||||
return importlib.util.find_spec("cuml") is not None
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
("base_topic_model"),
|
||||
("kmeans_pca_topic_model"),
|
||||
("custom_topic_model"),
|
||||
("merged_topic_model"),
|
||||
("reduced_topic_model"),
|
||||
("online_topic_model"),
|
||||
("supervised_topic_model"),
|
||||
("representation_topic_model"),
|
||||
("zeroshot_topic_model"),
|
||||
pytest.param(
|
||||
"cuml_base_topic_model",
|
||||
marks=pytest.mark.skipif(not cuml_available(), reason="cuML not available"),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_full_model(model, documents, request):
|
||||
"""Tests the entire pipeline in one go. This serves as a sanity check to see if the default
|
||||
settings result in a good separation of topics.
|
||||
|
||||
NOTE: This does not cover all cases but merely combines it all together
|
||||
"""
|
||||
topic_model = copy.deepcopy(request.getfixturevalue(model))
|
||||
if model == "base_topic_model":
|
||||
topic_model.save(
|
||||
"model_dir",
|
||||
serialization="pytorch",
|
||||
save_ctfidf=True,
|
||||
save_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
)
|
||||
topic_model = BERTopic.load("model_dir")
|
||||
|
||||
if model == "cuml_base_topic_model":
|
||||
assert "cuml" in str(type(topic_model.umap_model)).lower()
|
||||
assert "cuml" in str(type(topic_model.hdbscan_model)).lower()
|
||||
|
||||
topics = topic_model.topics_
|
||||
|
||||
for topic in set(topics):
|
||||
words = topic_model.get_topic(topic)[:10]
|
||||
assert len(words) == 10
|
||||
|
||||
for topic in topic_model.get_topic_freq().Topic:
|
||||
words = topic_model.get_topic(topic)[:10]
|
||||
assert len(words) == 10
|
||||
|
||||
assert len(topic_model.get_topic_freq()) > 2
|
||||
assert len(topic_model.get_topics()) == len(topic_model.get_topic_freq())
|
||||
|
||||
# Test extraction of document info
|
||||
document_info = topic_model.get_document_info(documents)
|
||||
assert len(document_info) == len(documents)
|
||||
|
||||
# Test transform
|
||||
doc = "This is a new document to predict."
|
||||
topics_test, probs_test = topic_model.transform([doc, doc])
|
||||
|
||||
assert len(topics_test) == 2
|
||||
|
||||
# Test zero-shot topic modeling
|
||||
if topic_model._is_zeroshot():
|
||||
if topic_model._outliers:
|
||||
assert set(topic_model.topic_labels_.keys()) == set(range(-1, len(topic_model.topic_labels_) - 1))
|
||||
else:
|
||||
assert set(topic_model.topic_labels_.keys()) == set(range(len(topic_model.topic_labels_)))
|
||||
|
||||
# Test topics over time
|
||||
timestamps = [i % 10 for i in range(len(documents))]
|
||||
topics_over_time = topic_model.topics_over_time(documents, timestamps)
|
||||
|
||||
assert topics_over_time.Frequency.sum() == len(documents)
|
||||
assert len(topics_over_time.Topic.unique()) == len(set(topics))
|
||||
|
||||
# Test hierarchical topics
|
||||
hier_topics = topic_model.hierarchical_topics(documents)
|
||||
|
||||
assert len(hier_topics) > 0
|
||||
assert hier_topics.Parent_ID.astype(int).min() > max(topics)
|
||||
|
||||
# Test creation of topic tree
|
||||
tree = topic_model.get_topic_tree(hier_topics, tight_layout=False)
|
||||
assert isinstance(tree, str)
|
||||
assert len(tree) > 10
|
||||
|
||||
# Test find topic
|
||||
similar_topics, similarity = topic_model.find_topics("query", top_n=2)
|
||||
assert len(similar_topics) == 2
|
||||
assert len(similarity) == 2
|
||||
assert max(similarity) <= 1
|
||||
|
||||
# Test topic reduction
|
||||
nr_topics = len(set(topics))
|
||||
nr_topics = 2 if nr_topics < 2 else nr_topics - 1
|
||||
topic_model.reduce_topics(documents, nr_topics=nr_topics)
|
||||
|
||||
assert len(topic_model.get_topic_freq()) == nr_topics
|
||||
assert len(topic_model.topics_) == len(topics)
|
||||
|
||||
# Test update topics
|
||||
topic = topic_model.get_topic(1)[:10]
|
||||
vectorizer_model = topic_model.vectorizer_model
|
||||
topic_model.update_topics(documents, n_gram_range=(2, 2))
|
||||
|
||||
updated_topic = topic_model.get_topic(1)[:10]
|
||||
|
||||
topic_model.update_topics(documents, vectorizer_model=vectorizer_model)
|
||||
original_topic = topic_model.get_topic(1)[:10]
|
||||
|
||||
assert topic != updated_topic
|
||||
if topic_model.representation_model is not None:
|
||||
assert topic != original_topic
|
||||
|
||||
# Test updating topic labels
|
||||
topic_labels = topic_model.generate_topic_labels(nr_words=3, topic_prefix=False, word_length=10, separator=", ")
|
||||
assert len(topic_labels) == len(set(topic_model.topics_))
|
||||
|
||||
# Test setting topic labels
|
||||
topic_model.set_topic_labels(topic_labels)
|
||||
assert topic_model.custom_labels_ == topic_labels
|
||||
|
||||
# Test merging topics
|
||||
freq = topic_model.get_topic_freq(0)
|
||||
topics_to_merge = [0, 1]
|
||||
topic_model.merge_topics(documents, topics_to_merge)
|
||||
assert freq < topic_model.get_topic_freq(0)
|
||||
|
||||
# Test reduction of outliers
|
||||
if -1 in topics:
|
||||
new_topics = topic_model.reduce_outliers(documents, topics, threshold=0.0)
|
||||
nr_outliers_topic_model = sum([1 for topic in topic_model.topics_ if topic == -1])
|
||||
nr_outliers_new_topics = sum([1 for topic in new_topics if topic == -1])
|
||||
|
||||
if topic_model._outliers == 1:
|
||||
assert nr_outliers_topic_model > nr_outliers_new_topics
|
||||
|
||||
# Combine models
|
||||
topic_model1 = BERTopic.load("model_dir")
|
||||
merged_model = BERTopic.merge_models([topic_model, topic_model1])
|
||||
|
||||
assert len(merged_model.get_topic_info()) > len(topic_model.get_topic_info())
|
||||
Reference in New Issue
Block a user