Add BERTopic.
This commit is contained in:
@@ -0,0 +1,68 @@
|
||||
from bertopic.backend import BaseEmbedder
|
||||
from sklearn.utils.validation import check_is_fitted, NotFittedError
|
||||
|
||||
|
||||
class SklearnEmbedder(BaseEmbedder):
|
||||
"""Scikit-Learn based embedding model.
|
||||
|
||||
This component allows the usage of scikit-learn pipelines for generating document and
|
||||
word embeddings.
|
||||
|
||||
Arguments:
|
||||
pipe: A scikit-learn pipeline that can `.transform()` text.
|
||||
|
||||
Examples:
|
||||
Scikit-Learn is very flexible and it allows for many representations.
|
||||
A relatively simple pipeline is shown below.
|
||||
|
||||
```python
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.decomposition import TruncatedSVD
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
|
||||
from bertopic.backend import SklearnEmbedder
|
||||
|
||||
pipe = make_pipeline(
|
||||
TfidfVectorizer(),
|
||||
TruncatedSVD(100)
|
||||
)
|
||||
|
||||
sklearn_embedder = SklearnEmbedder(pipe)
|
||||
topic_model = BERTopic(embedding_model=sklearn_embedder)
|
||||
```
|
||||
|
||||
This pipeline first constructs a sparse representation based on TF/idf and then
|
||||
makes it dense by applying SVD. Alternatively, you might also construct something
|
||||
more elaborate. As long as you construct a scikit-learn compatible pipeline, you
|
||||
should be able to pass it to Bertopic.
|
||||
|
||||
!!! Warning
|
||||
One caveat to be aware of is that scikit-learns base `Pipeline` class does not
|
||||
support the `.partial_fit()`-API. If you have a pipeline that theoretically should
|
||||
be able to support online learning then you might want to explore
|
||||
the [scikit-partial](https://github.com/koaning/scikit-partial) project.
|
||||
"""
|
||||
|
||||
def __init__(self, pipe):
|
||||
super().__init__()
|
||||
self.pipe = pipe
|
||||
|
||||
def embed(self, documents, verbose=False):
|
||||
"""Embed a list of n documents/words into an n-dimensional
|
||||
matrix of embeddings.
|
||||
|
||||
Arguments:
|
||||
documents: A list of documents or words to be embedded
|
||||
verbose: No-op variable that's kept around to keep the API consistent. If you want to get feedback on training times, you should use the sklearn API.
|
||||
|
||||
Returns:
|
||||
Document/words embeddings with shape (n, m) with `n` documents/words
|
||||
that each have an embeddings size of `m`
|
||||
"""
|
||||
try:
|
||||
check_is_fitted(self.pipe)
|
||||
embeddings = self.pipe.transform(documents)
|
||||
except NotFittedError:
|
||||
embeddings = self.pipe.fit_transform(documents)
|
||||
|
||||
return embeddings
|
||||
Reference in New Issue
Block a user