Files
bettafish-company/LLMTopicDetection_BERTopic/bertopic/backend/_sentencetransformers.py
T
戒酒的李白 c5c530775e Add BERTopic.
2025-08-12 19:01:20 +08:00

86 lines
3.2 KiB
Python

import numpy as np
from typing import List, Union
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
from bertopic.backend import BaseEmbedder
class SentenceTransformerBackend(BaseEmbedder):
    """Sentence-transformers embedding model.

    The sentence-transformers embedding model used for generating document and
    word embeddings.

    Arguments:
        embedding_model: A sentence-transformers embedding model. Either an
                         already instantiated `SentenceTransformer`, or a
                         string that is used to load one.
        model2vec: Indicates whether `embedding_model` is a model2vec model.
                   NOTE: Only works if `embedding_model` is a string; the flag
                   is ignored when an instantiated model is passed.
                   Otherwise, you can pass the model2vec model directly to `embedding_model`.

    Raises:
        ValueError: If `embedding_model` is neither a string nor a
                    `SentenceTransformer` instance.

    Examples:
    To create a model, you can load in a string pointing to a
    sentence-transformers model:

    ```python
    from bertopic.backend import SentenceTransformerBackend

    sentence_model = SentenceTransformerBackend("all-MiniLM-L6-v2")
    ```

    or you can instantiate a model yourself:

    ```python
    from bertopic.backend import SentenceTransformerBackend
    from sentence_transformers import SentenceTransformer

    embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
    sentence_model = SentenceTransformerBackend(embedding_model)
    ```

    If you want to use a model2vec model without having to install model2vec,
    you can pass the model2vec model as a string:

    ```python
    from bertopic.backend import SentenceTransformerBackend

    sentence_model = SentenceTransformerBackend("minishlab/potion-base-8M", model2vec=True)
    ```
    """

    def __init__(self, embedding_model: Union[str, SentenceTransformer], model2vec: bool = False):
        super().__init__()

        # Only populated when the model is loaded from a plain model name
        # (the last string branch below); stays None in every other case.
        self._hf_model = None

        if model2vec and isinstance(embedding_model, str):
            # Wrap the model2vec weights in a StaticEmbedding module so they
            # can be driven through the regular SentenceTransformer interface.
            static_embedding = StaticEmbedding.from_model2vec(embedding_model)
            self.embedding_model = SentenceTransformer(modules=[static_embedding])
        elif isinstance(embedding_model, SentenceTransformer):
            # Already instantiated: use as-is (`model2vec` has no effect here).
            self.embedding_model = embedding_model
        elif isinstance(embedding_model, str):
            self.embedding_model = SentenceTransformer(embedding_model)
            self._hf_model = embedding_model
        else:
            raise ValueError(
                "Please select a correct SentenceTransformers model: \n"
                "`from sentence_transformers import SentenceTransformer` \n"
                "`model = SentenceTransformer('all-MiniLM-L6-v2')`"
            )

    def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
        """Embed a list of n documents/words into an n-dimensional
        matrix of embeddings.

        Arguments:
            documents: A list of documents or words to be embedded
            verbose: Controls the verbosity of the process; forwarded to the
                     model's `encode` call as `show_progress_bar`

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`
        """
        embeddings = self.embedding_model.encode(documents, show_progress_bar=verbose)
        return embeddings