Add BERTopic.
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
from ._base import BaseCluster
|
||||
|
||||
# Names exported by a star-import of this package.
__all__ = [
    "BaseCluster",
]
|
||||
@@ -0,0 +1,41 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BaseCluster:
    """The Base Cluster class.

    Using this class directly in BERTopic will make it skip
    over the cluster step. As a result, topics need to be passed
    to BERTopic in the form of its `y` parameter in order to create
    topic representations.

    Examples:
    This will skip over the cluster step in BERTopic:

    ```python
    from bertopic import BERTopic
    from bertopic.cluster import BaseCluster

    empty_cluster_model = BaseCluster()

    topic_model = BERTopic(hdbscan_model=empty_cluster_model)
    ```

    Then, this class can be used to perform manual topic modeling.
    That is, topic modeling on topics that were already generated before
    without the need to learn them:

    ```python
    topic_model.fit(docs, y=y)
    ```
    """

    def fit(self, X, y=None):
        """Store the provided labels without performing any clustering.

        Arguments:
            X: The input data. It is accepted for interface compatibility
               and is not inspected.
            y: The pre-computed labels. When omitted, `labels_` is set
               to `None`.

        Returns:
            The instance itself, so that calls can be chained.
        """
        # No clustering happens here: the labels are whatever the caller
        # supplied (including `None` when nothing was supplied).
        self.labels_ = y
        return self

    def transform(self, X: np.ndarray) -> np.ndarray:
        """Return the input unchanged.

        Arguments:
            X: The input data.

        Returns:
            The very same object that was passed in as `X`.
        """
        return X
|
||||
@@ -0,0 +1,81 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):
    """Function used to select the HDBSCAN-like model for generating
    predictions and probabilities.

    Arguments:
        model: The cluster model.
        func: The function to use. Options:
                - "approximate_predict"
                - "all_points_membership_vectors"
                - "membership_vector"
        embeddings: Input embeddings for "approximate_predict"
                    and "membership_vector"

    Returns:
        - For "approximate_predict": a `(predictions, probabilities)` tuple.
          Models that are neither `hdbscan.HDBSCAN` nor cuML HDBSCAN fall
          back to `model.predict(embeddings)` and yield `(predictions, None)`.
        - For "all_points_membership_vectors" and "membership_vector": the
          result of the backend function of the same name, or `None` when
          the model is neither `hdbscan.HDBSCAN` nor cuML HDBSCAN.
        - `None` for any other value of `func`.
    """
    try:
        import hdbscan
    except ImportError:  # also covers ModuleNotFoundError (a subclass)
        hdbscan = None

    # `isinstance(model, None)` raises a TypeError, so the check must be
    # skipped entirely when the optional `hdbscan` package is unavailable.
    is_hdbscan = hdbscan is not None and isinstance(model, hdbscan.HDBSCAN)

    # cuML models are recognised by their type name so that `cuml` only has
    # to be imported when such a model is actually passed in.
    model_type_name = str(type(model)).lower()
    is_cuml_hdbscan = "cuml" in model_type_name and "hdbscan" in model_type_name

    # Approximate predict
    if func == "approximate_predict":
        if is_hdbscan:
            predictions, probabilities = hdbscan.approximate_predict(model, embeddings)
            return predictions, probabilities

        if is_cuml_hdbscan:
            from cuml.cluster import hdbscan as cuml_hdbscan

            predictions, probabilities = cuml_hdbscan.approximate_predict(model, embeddings)
            return predictions, probabilities

        # Generic cluster models only expose hard predictions.
        predictions = model.predict(embeddings)
        return predictions, None

    # All points membership
    if func == "all_points_membership_vectors":
        if is_hdbscan:
            return hdbscan.all_points_membership_vectors(model)

        if is_cuml_hdbscan:
            from cuml.cluster import hdbscan as cuml_hdbscan

            return cuml_hdbscan.all_points_membership_vectors(model)

        return None

    # membership_vector
    if func == "membership_vector":
        if is_hdbscan:
            probabilities = hdbscan.membership_vector(model, embeddings)
            return probabilities

        if is_cuml_hdbscan:
            from cuml.cluster import hdbscan as cuml_hdbscan

            probabilities = cuml_hdbscan.membership_vector(model, embeddings)
            return probabilities

        return None

    # Unknown `func` values yield None, as before.
    return None
|
||||
|
||||
|
||||
def is_supported_hdbscan(model):
    """Check whether the input model is a supported HDBSCAN-like model.

    Arguments:
        model: The cluster model to inspect.

    Returns:
        `True` when `model` is an `hdbscan.HDBSCAN` instance, or when its
        type name contains both "cuml" and "hdbscan" (case-insensitive);
        `False` otherwise.
    """
    try:
        import hdbscan
    except ImportError:  # also covers ModuleNotFoundError (a subclass)
        hdbscan = None

    # `isinstance(model, None)` raises a TypeError, so the check is only
    # performed when the optional `hdbscan` package could be imported.
    if hdbscan is not None and isinstance(model, hdbscan.HDBSCAN):
        return True

    # cuML models are detected by name to avoid importing `cuml` here.
    model_type_name = str(type(model)).lower()
    return "cuml" in model_type_name and "hdbscan" in model_type_name
|
||||
Reference in New Issue
Block a user