Add BERTopic.
This commit is contained in:
@@ -0,0 +1,121 @@
|
||||
After reducing the dimensionality of our input embeddings, we need to cluster them into groups of similar embeddings to extract our topics.
|
||||
This process of clustering is quite important because the more performant our clustering technique the more accurate our topic representations are.
|
||||
|
||||
In BERTopic, we typically use HDBSCAN as it is quite capable of capturing structures with different densities. However, there is not one perfect
|
||||
clustering model and you might want to be using something entirely different for your use case. Moreover, what if a new state-of-the-art model
|
||||
is released tomorrow? We would like to be able to use that in BERTopic, right? Since BERTopic assumes some independence among steps, we can allow for this modularity:
|
||||
|
||||
<figure markdown>
|
||||

|
||||
<figcaption></figcaption>
|
||||
</figure>
|
||||
|
||||
As a result, the `hdbscan_model` parameter in BERTopic now allows for a variety of clustering models. To do so, the class should have
|
||||
the following attributes:
|
||||
|
||||
* `.fit(X)`
|
||||
* A function that can be used to fit the model
|
||||
* `.predict(X)`
|
||||
* A predict function that transforms the input to cluster labels
|
||||
* `.labels_`
|
||||
* The labels after fitting the model
|
||||
|
||||
|
||||
In other words, it should have the following structure:
|
||||
|
||||
```python
|
||||
class ClusterModel:
|
||||
def fit(self, X):
|
||||
self.labels_ = None
|
||||
return self
|
||||
|
||||
def predict(self, X):
|
||||
return X
|
||||
```
|
||||
|
||||
In this section, we will go through several examples of clustering algorithms and how they can be implemented.
|
||||
|
||||
|
||||
## **HDBSCAN**
|
||||
As a default, BERTopic uses HDBSCAN to perform its clustering. To use a HDBSCAN model with custom parameters,
|
||||
we simply define it and pass it to BERTopic:
|
||||
|
||||
```python
|
||||
from bertopic import BERTopic
|
||||
from hdbscan import HDBSCAN
|
||||
|
||||
hdbscan_model = HDBSCAN(min_cluster_size=15, metric='euclidean', cluster_selection_method='eom', prediction_data=True)
|
||||
topic_model = BERTopic(hdbscan_model=hdbscan_model)
|
||||
```
|
||||
|
||||
Here, we can define any parameters in HDBSCAN to optimize for the best performance based on whatever validation metrics you are using.
|
||||
|
||||
## **k-Means**
|
||||
Although HDBSCAN works quite well in BERTopic and is typically advised, you might want to be using k-Means instead.
|
||||
It allows you to select how many clusters you would like and forces every single point to be in a cluster. Therefore, no
|
||||
outliers will be created. This also has disadvantages. When you force every single point in a cluster, it will mean
|
||||
that the cluster is highly likely to contain noise which can hurt the topic representations. As a small tip, using
|
||||
the `vectorizer_model=CountVectorizer(stop_words="english")` helps quite a bit to then improve the topic representation.
|
||||
|
||||
Having said that, using k-Means is quite straightforward:
|
||||
|
||||
```python
|
||||
from bertopic import BERTopic
|
||||
from sklearn.cluster import KMeans
|
||||
|
||||
cluster_model = KMeans(n_clusters=50)
|
||||
topic_model = BERTopic(hdbscan_model=cluster_model)
|
||||
```
|
||||
|
||||
!!! note
|
||||
As you might have noticed, the `cluster_model` is passed to `hdbscan_model` which might be a bit confusing considering
|
||||
you are not passing an HDBSCAN model. For now, the name of the parameter is kept the same to adhere to the current
|
||||
state of the API. Changing the name could lead to deprecation issues, which I want to prevent as much as possible.
|
||||
|
||||
## **Agglomerative Clustering**
|
||||
Like k-Means, there are a bunch more clustering algorithms in `sklearn` that you can be using. Some of these models do
|
||||
not have a `.predict()` method but still can be used in BERTopic. However, using BERTopic's `.transform()` function
|
||||
will then give errors.
|
||||
|
||||
Here, we will demonstrate Agglomerative Clustering:
|
||||
|
||||
|
||||
```python
|
||||
from bertopic import BERTopic
|
||||
from sklearn.cluster import AgglomerativeClustering
|
||||
|
||||
cluster_model = AgglomerativeClustering(n_clusters=50)
|
||||
topic_model = BERTopic(hdbscan_model=cluster_model)
|
||||
```
|
||||
|
||||
|
||||
## **cuML HDBSCAN**
|
||||
|
||||
Although the original HDBSCAN implementation is an amazing technique, it may have difficulty handling large amounts of data. Instead,
|
||||
we can use [cuML](https://rapids.ai/start.html#rapids-release-selector) to speed up HDBSCAN through GPU acceleration:
|
||||
|
||||
```python
|
||||
from bertopic import BERTopic
|
||||
from cuml.cluster import HDBSCAN
|
||||
|
||||
hdbscan_model = HDBSCAN(min_samples=10, gen_min_span_tree=True, prediction_data=True)
|
||||
topic_model = BERTopic(hdbscan_model=hdbscan_model)
|
||||
```
|
||||
|
||||
The great thing about using cuML's HDBSCAN implementation is that it supports many features of the original implementation. In other words,
|
||||
`calculate_probabilities=True` also works!
|
||||
|
||||
!!! note
|
||||
As of the v0.13 release, it is not yet possible to calculate the topic-document probability matrix for unseen data (i.e., `.transform`) using cuML's HDBSCAN.
|
||||
However, it is still possible to calculate the topic-document probability matrix for the data on which the model was trained (i.e., `.fit` and `.fit_transform`).
|
||||
|
||||
!!! note
|
||||
If you want to install cuML together with BERTopic using Google Colab, you can run the following code:
|
||||
|
||||
```bash
|
||||
!pip install bertopic
|
||||
!pip install cudf-cu11 dask-cudf-cu11 --extra-index-url=https://pypi.nvidia.com
|
||||
!pip install cuml-cu11 --extra-index-url=https://pypi.nvidia.com
|
||||
!pip install cugraph-cu11 --extra-index-url=https://pypi.nvidia.com
|
||||
!pip install --upgrade cupy-cuda11x -f https://pip.cupy.dev/aarch64
|
||||
```
|
||||
@@ -0,0 +1,53 @@
|
||||
<svg width="445" height="268" viewBox="0 0 445 268" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect x="132" y="230" width="118" height="38" fill="#64B5F6"/>
|
||||
<rect x="224" y="220" width="20" height="8" fill="#64B5F6"/>
|
||||
<rect x="196" y="220" width="20" height="8" fill="#64B5F6"/>
|
||||
<rect x="168" y="220" width="20" height="8" fill="#64B5F6"/>
|
||||
<rect x="140" y="220" width="20" height="8" fill="#64B5F6"/>
|
||||
<text fill="white" xml:space="preserve" style="white-space: pre" font-family="Tahoma" font-size="20" font-weight="bold" letter-spacing="0em"><tspan x="158.256" y="257.939">SBERT</tspan></text>
|
||||
<rect x="132" y="190" width="118" height="38" fill="#E57373"/>
|
||||
<rect x="224" y="180" width="20" height="8" fill="#E57373"/>
|
||||
<rect x="196" y="180" width="20" height="8" fill="#E57373"/>
|
||||
<rect x="168" y="180" width="20" height="8" fill="#E57373"/>
|
||||
<rect x="140" y="180" width="20" height="8" fill="#E57373"/>
|
||||
<text fill="white" xml:space="preserve" style="white-space: pre" font-family="Tahoma" font-size="20" font-weight="bold" letter-spacing="0em"><tspan x="161.254" y="217.939">UMAP</tspan></text>
|
||||
<rect y="150" width="118" height="38" fill="#4DB6AC"/>
|
||||
<rect x="92" y="140" width="20" height="8" fill="#4DB6AC"/>
|
||||
<rect x="64" y="140" width="20" height="8" fill="#4DB6AC"/>
|
||||
<rect x="36" y="140" width="20" height="8" fill="#4DB6AC"/>
|
||||
<rect x="8" y="140" width="20" height="8" fill="#4DB6AC"/>
|
||||
<text fill="white" xml:space="preserve" style="white-space: pre" font-family="Tahoma" font-size="20" font-weight="bold" letter-spacing="0em"><tspan x="9.3418" y="177.939">HDBSCAN</tspan></text>
|
||||
<rect x="132" y="90" width="118" height="38" fill="#FFD54F"/>
|
||||
<rect x="224" y="80" width="20" height="8" fill="#FFD54F"/>
|
||||
<rect x="196" y="80" width="20" height="8" fill="#FFD54F"/>
|
||||
<rect x="168" y="80" width="20" height="8" fill="#FFD54F"/>
|
||||
<rect x="140" y="80" width="20" height="8" fill="#FFD54F"/>
|
||||
<text fill="white" xml:space="preserve" style="white-space: pre" font-family="Tahoma" font-size="13" font-weight="bold" letter-spacing="0em"><tspan x="138.346" y="114.161">CountVectorizer</tspan></text>
|
||||
<rect x="132" y="50" width="118" height="38" fill="#90A4AE"/>
|
||||
<rect x="224" y="40" width="20" height="8" fill="#90A4AE"/>
|
||||
<rect x="196" y="40" width="20" height="8" fill="#90A4AE"/>
|
||||
<rect x="168" y="40" width="20" height="8" fill="#90A4AE"/>
|
||||
<rect x="140" y="40" width="20" height="8" fill="#90A4AE"/>
|
||||
<text fill="white" xml:space="preserve" style="white-space: pre" font-family="Tahoma" font-size="20" font-weight="bold" letter-spacing="0em"><tspan x="146.938" y="77.9395">c-TF-IDF</tspan></text>
|
||||
<rect x="132" y="10" width="118" height="38" fill="#3F51B5"/>
|
||||
<rect x="224" width="20" height="8" fill="#3F51B5"/>
|
||||
<rect x="196" width="20" height="8" fill="#3F51B5"/>
|
||||
<rect x="168" width="20" height="8" fill="#3F51B5"/>
|
||||
<rect x="140" width="20" height="8" fill="#3F51B5"/>
|
||||
<text fill="white" xml:space="preserve" style="white-space: pre" font-family="Tahoma" font-size="14" font-weight="bold" letter-spacing="0em"><tspan x="161.065" y="25.0576">Optional </tspan><tspan x="150.271" y="42.0576">Fine-tuning</tspan></text>
|
||||
<rect x="132" y="150" width="118" height="38" fill="#4DB6AC"/>
|
||||
<rect x="224" y="140" width="20" height="8" fill="#4DB6AC"/>
|
||||
<rect x="196" y="140" width="20" height="8" fill="#4DB6AC"/>
|
||||
<rect x="168" y="140" width="20" height="8" fill="#4DB6AC"/>
|
||||
<rect x="140" y="140" width="20" height="8" fill="#4DB6AC"/>
|
||||
<text fill="white" xml:space="preserve" style="white-space: pre" font-family="Tahoma" font-size="20" font-weight="bold" letter-spacing="0em"><tspan x="148.246" y="177.939">k-Means</tspan></text>
|
||||
<rect x="327" y="150" width="118" height="38" fill="#4DB6AC"/>
|
||||
<rect x="419" y="140" width="20" height="8" fill="#4DB6AC"/>
|
||||
<rect x="391" y="140" width="20" height="8" fill="#4DB6AC"/>
|
||||
<rect x="363" y="140" width="20" height="8" fill="#4DB6AC"/>
|
||||
<rect x="335" y="140" width="20" height="8" fill="#4DB6AC"/>
|
||||
<text fill="white" xml:space="preserve" style="white-space: pre" font-family="Tahoma" font-size="20" font-weight="bold" letter-spacing="0em"><tspan x="351.709" y="177.939">BIRCH</tspan></text>
|
||||
<circle cx="266.5" cy="168.5" r="5.5" fill="black"/>
|
||||
<circle cx="285.5" cy="168.5" r="5.5" fill="black"/>
|
||||
<circle cx="307.5" cy="168.5" r="5.5" fill="black"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 4.2 KiB |
Reference in New Issue
Block a user