Add BERTopic.

This commit is contained in:
戒酒的李白
2025-08-12 19:01:20 +08:00
parent e2323d579c
commit c5c530775e
256 changed files with 28666 additions and 0 deletions
@@ -0,0 +1,76 @@
from bertopic._utils import NotInstalled
from bertopic.representation._cohere import Cohere
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._keybert import KeyBERTInspired
from bertopic.representation._mmr import MaximalMarginalRelevance
# Llama CPP Generator
try:
from bertopic.representation._llamacpp import LlamaCPP
except ModuleNotFoundError:
msg = "`pip install llama-cpp-python` \n\n"
LlamaCPP = NotInstalled("llama.cpp", "llama-cpp-python", custom_msg=msg)
# Text Generation using transformers
try:
from bertopic.representation._textgeneration import TextGeneration
except ModuleNotFoundError:
msg = "`pip install bertopic` without `--no-deps` \n\n"
TextGeneration = NotInstalled("TextGeneration", "transformers", custom_msg=msg)
# Zero-shot classification using transformers
try:
from bertopic.representation._zeroshot import ZeroShotClassification
except ModuleNotFoundError:
msg = "`pip install bertopic` without `--no-deps` \n\n"
ZeroShotClassification = NotInstalled("ZeroShotClassification", "transformers", custom_msg=msg)
# OpenAI Generator
try:
from bertopic.representation._openai import OpenAI
except ModuleNotFoundError:
msg = "`pip install openai` \n\n"
OpenAI = NotInstalled("OpenAI", "openai", custom_msg=msg)
# LiteLLM Generator
try:
from bertopic.representation._litellm import LiteLLM
except ModuleNotFoundError:
msg = "`pip install litellm` \n\n"
LiteLLM = NotInstalled("LiteLLM", "litellm", custom_msg=msg)
# LangChain Generator
try:
from bertopic.representation._langchain import LangChain
except ModuleNotFoundError:
msg = "`pip install langchain` \n\n"
LangChain = NotInstalled("langchain", "langchain", custom_msg=msg)
# POS using Spacy
try:
from bertopic.representation._pos import PartOfSpeech
except ModuleNotFoundError:
PartOfSpeech = NotInstalled("Part of Speech with Spacy", "spacy")
# Multimodal
try:
from bertopic.representation._visual import VisualRepresentation
except ModuleNotFoundError:
VisualRepresentation = NotInstalled("a visual representation model", "vision")
__all__ = [
"BaseRepresentation",
"TextGeneration",
"ZeroShotClassification",
"KeyBERTInspired",
"PartOfSpeech",
"MaximalMarginalRelevance",
"Cohere",
"OpenAI",
"LangChain",
"LiteLLM",
"LlamaCPP",
"VisualRepresentation",
]
@@ -0,0 +1,40 @@
import pandas as pd
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator
from typing import Mapping, List, Tuple
class BaseRepresentation(BaseEstimator):
"""The base representation model for fine-tuning topic representations."""
def extract_topics(
self,
topic_model,
documents: pd.DataFrame,
c_tf_idf: csr_matrix,
topics: Mapping[str, List[Tuple[str, float]]],
) -> Mapping[str, List[Tuple[str, float]]]:
"""Extract topics.
Each representation model that inherits this class will have
its arguments (topic_model, documents, c_tf_idf, topics)
automatically passed. Therefore, the representation model
will only have access to the information about topics related
to those arguments.
Arguments:
topic_model: The BERTopic model that is fitted until topic
representations are calculated.
documents: A dataframe with columns "Document" and "Topic"
that contains all documents with each corresponding
topic.
c_tf_idf: A c-TF-IDF representation that is typically
identical to `topic_model.c_tf_idf_` except for
dynamic, class-based, and hierarchical topic modeling
where it is calculated on a subset of the documents.
topics: A dictionary with topic (key) and tuple of word and
weight (value) as calculated by c-TF-IDF. This is the
default topics that are returned if no representation
model is used.
"""
return topic_model.topic_representations_
@@ -0,0 +1,209 @@
import time
import pandas as pd
from tqdm import tqdm
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple, Union, Callable
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters
DEFAULT_PROMPT = """
This is a list of texts where each collection of texts describe a topic. After each collection of texts, the name of the topic they represent is mentioned as a short-highly-descriptive title
---
Topic:
Sample texts from this topic:
- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
- Meat, but especially beef, is the word food in terms of emissions.
- Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one.
Keywords: meat beef eat eating emissions steak food health processed chicken
Topic name: Environmental impacts of eating meat
---
Topic:
Sample texts from this topic:
- I have ordered the product weeks ago but it still has not arrived!
- The website mentions that it only takes a couple of days to deliver but I still have not received mine.
- I got a message stating that I received the monitor but that is not true!
- It took a month longer to deliver than was advised...
Keywords: deliver weeks product shipping long delivery received arrived arrive week
Topic name: Shipping and delivery issues
---
Topic:
Sample texts from this topic:
[DOCUMENTS]
Keywords: [KEYWORDS]
Topic name:"""
DEFAULT_SYSTEM_PROMPT = "You are an assistant that extracts high-level topics from texts."
class Cohere(BaseRepresentation):
"""Use the Cohere API to generate topic labels based on their
generative model.
Find more about their models here:
https://docs.cohere.ai/docs
Arguments:
client: A `cohere.Client`
model: Model to use within Cohere, defaults to `"xlarge"`.
prompt: The prompt to be used in the model. If no prompt is given,
`self.default_prompt_` is used instead.
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
to decide where the keywords and documents need to be
inserted.
system_prompt: The system prompt to be used in the model. If no system prompt is given,
`self.default_system_prompt_` is used instead.
delay_in_seconds: The delay in seconds between consecutive prompts
in order to prevent RateLimitErrors.
nr_docs: The number of documents to pass to OpenAI if a prompt
with the `["DOCUMENTS"]` tag is used.
diversity: The diversity of documents to pass to OpenAI.
Accepts values between 0 and 1. A higher
values results in passing more diverse documents
whereas lower values passes more similar documents.
doc_length: The maximum length of each document. If a document is longer,
it will be truncated. If None, the entire document is passed.
tokenizer: The tokenizer used to calculate to split the document into segments
used to count the length of a document.
* If tokenizer is 'char', then the document is split up
into characters which are counted to adhere to `doc_length`
* If tokenizer is 'whitespace', the document is split up
into words separated by whitespaces. These words are counted
and truncated depending on `doc_length`
* If tokenizer is 'vectorizer', then the internal CountVectorizer
is used to tokenize the document. These tokens are counted
and truncated depending on `doc_length`
* If tokenizer is a callable, then that callable is used to tokenize
the document. These tokens are counted and truncated depending
on `doc_length`
Usage:
To use this, you will need to install cohere first:
`pip install cohere`
Then, get yourself an API key and use Cohere's API as follows:
```python
import cohere
from bertopic.representation import Cohere
from bertopic import BERTopic
# Create your representation model
co = cohere.Client(my_api_key)
representation_model = Cohere(co)
# Use the representation model in BERTopic on top of the default pipeline
topic_model = BERTopic(representation_model=representation_model)
```
You can also use a custom prompt:
```python
prompt = "I have the following documents: [DOCUMENTS]. What topic do they contain?"
representation_model = Cohere(co, prompt=prompt)
```
"""
def __init__(
self,
client,
model: str = "command-r",
prompt: str = None,
system_prompt: str = None,
delay_in_seconds: float = None,
nr_docs: int = 4,
diversity: float = None,
doc_length: int = None,
tokenizer: Union[str, Callable] = None,
):
self.client = client
self.model = model
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
self.system_prompt = system_prompt if system_prompt is not None else DEFAULT_SYSTEM_PROMPT
self.default_prompt_ = DEFAULT_PROMPT
self.default_system_prompt_ = DEFAULT_SYSTEM_PROMPT
self.delay_in_seconds = delay_in_seconds
self.nr_docs = nr_docs
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
validate_truncate_document_parameters(self.tokenizer, self.doc_length)
self.prompts_ = []
def extract_topics(
self,
topic_model,
documents: pd.DataFrame,
c_tf_idf: csr_matrix,
topics: Mapping[str, List[Tuple[str, float]]],
) -> Mapping[str, List[Tuple[str, float]]]:
"""Extract topics.
Arguments:
topic_model: Not used
documents: Not used
c_tf_idf: Not used
topics: The candidate topics as calculated with c-TF-IDF
Returns:
updated_topics: Updated topic representations
"""
# Extract the top 4 representative documents per topic
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity
)
# Generate using Cohere's Language Model
updated_topics = {}
for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
prompt = self._create_prompt(truncated_docs, topic, topics)
self.prompts_.append(prompt)
# Delay
if self.delay_in_seconds:
time.sleep(self.delay_in_seconds)
request = self.client.chat(
model=self.model,
preamble=self.system_prompt,
message=prompt,
max_tokens=50,
stop_sequences=["\n"],
)
label = request.text.strip()
updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)]
return updated_topics
def _create_prompt(self, docs, topic, topics):
keywords = list(zip(*topics[topic]))[0]
# Use the Default Chat Prompt
if self.prompt == DEFAULT_PROMPT:
prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
prompt = self._replace_documents(prompt, docs)
# Use a custom prompt that leverages keywords, documents or both using
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
else:
prompt = self.prompt
if "[KEYWORDS]" in prompt:
prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords))
if "[DOCUMENTS]" in prompt:
prompt = self._replace_documents(prompt, docs)
return prompt
@staticmethod
def _replace_documents(prompt, docs):
to_replace = ""
for doc in docs:
to_replace += f"- {doc}\n"
prompt = prompt.replace("[DOCUMENTS]", to_replace)
return prompt
@@ -0,0 +1,222 @@
import numpy as np
import pandas as pd
from packaging import version
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple, Union
from sklearn.metrics.pairwise import cosine_similarity
from bertopic.representation._base import BaseRepresentation
from sklearn import __version__ as sklearn_version
class KeyBERTInspired(BaseRepresentation):
def __init__(
self,
top_n_words: int = 10,
nr_repr_docs: int = 5,
nr_samples: int = 500,
nr_candidate_words: int = 100,
random_state: int = 42,
):
"""Use a KeyBERT-like model to fine-tune the topic representations.
The algorithm follows KeyBERT but does some optimization in
order to speed up inference.
The steps are as follows. First, we extract the top n representative
documents per topic. To extract the representative documents, we
randomly sample a number of candidate documents per cluster
which is controlled by the `nr_samples` parameter. Then,
the top n representative documents are extracted by calculating
the c-TF-IDF representation for the candidate documents and finding,
through cosine similarity, which are closest to the topic c-TF-IDF representation.
Next, the top n words per topic are extracted based on their
c-TF-IDF representation, which is controlled by the `nr_repr_docs`
parameter.
Then, we extract the embeddings for words and representative documents
and create topic embeddings by averaging the representative documents.
Finally, the most similar words to each topic are extracted by
calculating the cosine similarity between word and topic embeddings.
Arguments:
top_n_words: The top n words to extract per topic.
nr_repr_docs: The number of representative documents to extract per cluster.
nr_samples: The number of candidate documents to extract per cluster.
nr_candidate_words: The number of candidate words per cluster.
random_state: The random state for randomly sampling candidate documents.
Usage:
```python
from bertopic.representation import KeyBERTInspired
from bertopic import BERTopic
# Create your representation model
representation_model = KeyBERTInspired()
# Use the representation model in BERTopic on top of the default pipeline
topic_model = BERTopic(representation_model=representation_model)
```
"""
self.top_n_words = top_n_words
self.nr_repr_docs = nr_repr_docs
self.nr_samples = nr_samples
self.nr_candidate_words = nr_candidate_words
self.random_state = random_state
def extract_topics(
self,
topic_model,
documents: pd.DataFrame,
c_tf_idf: csr_matrix,
topics: Mapping[str, List[Tuple[str, float]]],
embeddings: np.ndarray = None,
) -> Mapping[str, List[Tuple[str, float]]]:
"""Extract topics.
Arguments:
topic_model: A BERTopic model
documents: All input documents
c_tf_idf: The topic c-TF-IDF representation
topics: The candidate topics as calculated with c-TF-IDF
embeddings: Pre-trained document embeddings. These can be used
instead of an embedding model
Returns:
updated_topics: Updated topic representations
"""
# We extract the top n representative documents per class
_, representative_docs, repr_doc_indices, _ = topic_model._extract_representative_docs(
c_tf_idf, documents, topics, self.nr_samples, self.nr_repr_docs
)
# If document embeddings are precomputed, extract the embeddings of the representative documents based on repr_doc_indices
repr_embeddings = None
if embeddings is not None:
repr_embeddings = [embeddings[index] for index in np.concatenate(repr_doc_indices)]
# We extract the top n words per class
topics = self._extract_candidate_words(topic_model, c_tf_idf, topics)
# We calculate the similarity between word and document embeddings and create
# topic embeddings from the representative document embeddings
sim_matrix, words = self._extract_embeddings(
topic_model, topics, representative_docs, repr_doc_indices, repr_embeddings
)
# Find the best matching words based on the similarity matrix for each topic
updated_topics = self._extract_top_words(words, topics, sim_matrix)
return updated_topics
def _extract_candidate_words(
self,
topic_model,
c_tf_idf: csr_matrix,
topics: Mapping[str, List[Tuple[str, float]]],
) -> Mapping[str, List[Tuple[str, float]]]:
"""For each topic, extract candidate words based on the c-TF-IDF
representation.
Arguments:
topic_model: A BERTopic model
c_tf_idf: The topic c-TF-IDF representation
topics: The top words per topic
Returns:
topics: The `self.top_n_words` per topic
"""
labels = [int(label) for label in sorted(list(topics.keys()))]
# Scikit-Learn Deprecation: get_feature_names is deprecated in 1.0
# and will be removed in 1.2. Please use get_feature_names_out instead.
if version.parse(sklearn_version) >= version.parse("1.0.0"):
words = topic_model.vectorizer_model.get_feature_names_out()
else:
words = topic_model.vectorizer_model.get_feature_names()
indices = topic_model._top_n_idx_sparse(c_tf_idf, self.nr_candidate_words)
scores = topic_model._top_n_values_sparse(c_tf_idf, indices)
sorted_indices = np.argsort(scores, 1)
indices = np.take_along_axis(indices, sorted_indices, axis=1)
scores = np.take_along_axis(scores, sorted_indices, axis=1)
# Get top 30 words per topic based on c-TF-IDF score
topics = {
label: [
(words[word_index], score) if word_index is not None and score > 0 else ("", 0.00001)
for word_index, score in zip(indices[index][::-1], scores[index][::-1])
]
for index, label in enumerate(labels)
}
topics = {label: list(zip(*values[: self.nr_candidate_words]))[0] for label, values in topics.items()}
return topics
def _extract_embeddings(
self,
topic_model,
topics: Mapping[str, List[Tuple[str, float]]],
representative_docs: List[str],
repr_doc_indices: List[List[int]],
repr_embeddings: np.ndarray = None,
) -> Union[np.ndarray, List[str]]:
"""Extract the representative document embeddings and create topic embeddings.
Then extract word embeddings and calculate the cosine similarity between topic
embeddings and the word embeddings. Topic embeddings are the average of
representative document embeddings.
Arguments:
topic_model: A BERTopic model
topics: The top words per topic
representative_docs: A flat list of representative documents
repr_doc_indices: The indices of representative documents
that belong to each topic
repr_embeddings: Embeddings of respective representative_docs
Returns:
sim: The similarity matrix between word and topic embeddings
vocab: The complete vocabulary of input documents
"""
# Calculate representative document embeddings if there are no precomputed embeddings.
if repr_embeddings is None:
repr_embeddings = topic_model._extract_embeddings(representative_docs, method="document", verbose=False)
topic_embeddings = [np.mean(repr_embeddings[i[0] : i[-1] + 1], axis=0) for i in repr_doc_indices]
# Calculate word embeddings and extract best matching with updated topic_embeddings
vocab = list(set([word for words in topics.values() for word in words]))
word_embeddings = topic_model._extract_embeddings(vocab, method="document", verbose=False)
sim = cosine_similarity(topic_embeddings, word_embeddings)
return sim, vocab
def _extract_top_words(
self,
vocab: List[str],
topics: Mapping[str, List[Tuple[str, float]]],
sim: np.ndarray,
) -> Mapping[str, List[Tuple[str, float]]]:
"""Extract the top n words per topic based on the
similarity matrix between topics and words.
Arguments:
vocab: The complete vocabulary of input documents
labels: All topic labels
topics: The top words per topic
sim: The similarity matrix between word and topic embeddings
Returns:
updated_topics: The updated topic representations
"""
labels = [int(label) for label in sorted(list(topics.keys()))]
updated_topics = {}
for i, topic in enumerate(labels):
indices = [vocab.index(word) for word in topics[topic]]
values = sim[:, indices][i]
word_indices = [indices[index] for index in np.argsort(values)[-self.top_n_words :]]
updated_topics[topic] = [
(vocab[index], val) for val, index in zip(np.sort(values)[-self.top_n_words :], word_indices)
][::-1]
return updated_topics
@@ -0,0 +1,213 @@
import pandas as pd
from langchain.docstore.document import Document
from scipy.sparse import csr_matrix
from typing import Callable, Mapping, List, Tuple, Union
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters
DEFAULT_PROMPT = "What are these documents about? Please give a single label."
class LangChain(BaseRepresentation):
"""Using chains in langchain to generate topic labels.
The classic example uses `langchain.chains.question_answering.load_qa_chain`.
This returns a chain that takes a list of documents and a question as input.
You can also use Runnables such as those composed using the LangChain Expression Language.
Arguments:
chain: The langchain chain or Runnable with a `batch` method.
Input keys must be `input_documents` and `question`.
Output key must be `output_text`.
prompt: The prompt to be used in the model. If no prompt is given,
`self.default_prompt_` is used instead.
NOTE: Use `"[KEYWORDS]"` in the prompt
to decide where the keywords need to be
inserted. Keywords won't be included unless
indicated. Unlike other representation models,
Langchain does not use the `"[DOCUMENTS]"` tag
to insert documents into the prompt. The load_qa_chain function
formats the representative documents within the prompt.
nr_docs: The number of documents to pass to LangChain
diversity: The diversity of documents to pass to LangChain.
Accepts values between 0 and 1. A higher
values results in passing more diverse documents
whereas lower values passes more similar documents.
doc_length: The maximum length of each document. If a document is longer,
it will be truncated. If None, the entire document is passed.
tokenizer: The tokenizer used to calculate to split the document into segments
used to count the length of a document.
* If tokenizer is 'char', then the document is split up
into characters which are counted to adhere to `doc_length`
* If tokenizer is 'whitespace', the document is split up
into words separated by whitespaces. These words are counted
and truncated depending on `doc_length`
* If tokenizer is 'vectorizer', then the internal CountVectorizer
is used to tokenize the document. These tokens are counted
and truncated depending on `doc_length`. They are decoded with
whitespaces.
* If tokenizer is a callable, then that callable is used to tokenize
the document. These tokens are counted and truncated depending
on `doc_length`
chain_config: The configuration for the langchain chain. Can be used to set options
like max_concurrency to avoid rate limiting errors.
Usage:
To use this, you will need to install the langchain package first.
Additionally, you will need an underlying LLM to support langchain,
like openai:
`pip install langchain`
`pip install openai`
Then, you can create your chain as follows:
```python
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import OpenAI
chain = load_qa_chain(OpenAI(temperature=0, openai_api_key=my_openai_api_key), chain_type="stuff")
```
Finally, you can pass the chain to BERTopic as follows:
```python
from bertopic.representation import LangChain
# Create your representation model
representation_model = LangChain(chain)
# Use the representation model in BERTopic on top of the default pipeline
topic_model = BERTopic(representation_model=representation_model)
```
You can also use a custom prompt:
```python
prompt = "What are these documents about? Please give a single label."
representation_model = LangChain(chain, prompt=prompt)
```
You can also use a Runnable instead of a chain.
The example below uses the LangChain Expression Language:
```python
from bertopic.representation import LangChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatAnthropic
from langchain.schema.document import Document
from langchain.schema.runnable import RunnablePassthrough
from langchain_experimental.data_anonymizer.presidio import PresidioReversibleAnonymizer
prompt = ...
llm = ...
# We will construct a special privacy-preserving chain using Microsoft Presidio
pii_handler = PresidioReversibleAnonymizer(analyzed_fields=["PERSON"])
chain = (
{
"input_documents": (
lambda inp: [
Document(
page_content=pii_handler.anonymize(
d.page_content,
language="en",
),
)
for d in inp["input_documents"]
]
),
"question": RunnablePassthrough(),
}
| load_qa_chain(representation_llm, chain_type="stuff")
| (lambda output: {"output_text": pii_handler.deanonymize(output["output_text"])})
)
representation_model = LangChain(chain, prompt=representation_prompt)
```
"""
def __init__(
self,
chain,
prompt: str = None,
nr_docs: int = 4,
diversity: float = None,
doc_length: int = None,
tokenizer: Union[str, Callable] = None,
chain_config=None,
):
self.chain = chain
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
self.default_prompt_ = DEFAULT_PROMPT
self.chain_config = chain_config
self.nr_docs = nr_docs
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
validate_truncate_document_parameters(self.tokenizer, self.doc_length)
def extract_topics(
self,
topic_model,
documents: pd.DataFrame,
c_tf_idf: csr_matrix,
topics: Mapping[str, List[Tuple[str, float]]],
) -> Mapping[str, List[Tuple[str, int]]]:
"""Extract topics.
Arguments:
topic_model: A BERTopic model
documents: All input documents
c_tf_idf: The topic c-TF-IDF representation
topics: The candidate topics as calculated with c-TF-IDF
Returns:
updated_topics: Updated topic representations
"""
# Extract the top 4 representative documents per topic
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
c_tf_idf=c_tf_idf,
documents=documents,
topics=topics,
nr_samples=500,
nr_repr_docs=self.nr_docs,
diversity=self.diversity,
)
# Generate label using langchain's batch functionality
chain_docs: List[List[Document]] = [
[
Document(page_content=truncate_document(topic_model, self.doc_length, self.tokenizer, doc))
for doc in docs
]
for docs in repr_docs_mappings.values()
]
# `self.chain` must take `input_documents` and `question` as input keys
# Use a custom prompt that leverages keywords, using the tag: [KEYWORDS]
if "[KEYWORDS]" in self.prompt:
prompts = []
for topic in topics:
keywords = list(zip(*topics[topic]))[0]
prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
prompts.append(prompt)
inputs = [{"input_documents": docs, "question": prompt} for docs, prompt in zip(chain_docs, prompts)]
else:
inputs = [{"input_documents": docs, "question": self.prompt} for docs in chain_docs]
# `self.chain` must return a dict with an `output_text` key
# same output key as the `StuffDocumentsChain` returned by `load_qa_chain`
outputs = self.chain.batch(inputs=inputs, config=self.chain_config)
labels = [output["output_text"].strip() for output in outputs]
updated_topics = {
topic: [(label, 1)] + [("", 0) for _ in range(9)] for topic, label in zip(repr_docs_mappings.keys(), labels)
}
return updated_topics
@@ -0,0 +1,176 @@
import time
from litellm import completion
import pandas as pd
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple, Any
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import retry_with_exponential_backoff
DEFAULT_PROMPT = """
I have a topic that contains the following documents:
[DOCUMENTS]
The topic is described by the following keywords: [KEYWORDS]
Based on the information above, extract a short topic label in the following format:
topic: <topic label>
"""
class LiteLLM(BaseRepresentation):
"""Using the LiteLLM API to generate topic labels.
For an overview of models see:
https://docs.litellm.ai/docs/providers
Arguments:
model: Model to use. Defaults to OpenAI's "gpt-3.5-turbo".
generator_kwargs: Kwargs passed to `litellm.completion`.
prompt: The prompt to be used in the model. If no prompt is given,
`self.default_prompt_` is used instead.
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
to decide where the keywords and documents need to be
inserted.
delay_in_seconds: The delay in seconds between consecutive prompts
in order to prevent RateLimitErrors.
exponential_backoff: Retry requests with a random exponential backoff.
A short sleep is used when a rate limit error is hit,
then the requests is retried. Increase the sleep length
if errors are hit until 10 unsuccesfull requests.
If True, overrides `delay_in_seconds`.
nr_docs: The number of documents to pass to LiteLLM if a prompt
with the `["DOCUMENTS"]` tag is used.
diversity: The diversity of documents to pass to LiteLLM.
Accepts values between 0 and 1. A higher
values results in passing more diverse documents
whereas lower values passes more similar documents.
Usage:
To use this, you will need to install the litellm package first:
`pip install litellm`
Then, get yourself an API key of any provider (for instance OpenAI) and use it as follows:
```python
import os
from bertopic.representation import LiteLLM
from bertopic import BERTopic
# set ENV variables
os.environ["OPENAI_API_KEY"] = "your-openai-key"
# Create your representation model
representation_model = LiteLLM(model="gpt-3.5-turbo")
# Use the representation model in BERTopic on top of the default pipeline
topic_model = BERTopic(representation_model=representation_model)
```
You can also use a custom prompt:
```python
prompt = "I have the following documents: [DOCUMENTS] \nThese documents are about the following topic: '"
representation_model = LiteLLM(model="gpt", prompt=prompt)
```
""" # noqa: D301
def __init__(
self,
model: str = "gpt-3.5-turbo",
prompt: str = None,
generator_kwargs: Mapping[str, Any] = {},
delay_in_seconds: float = None,
exponential_backoff: bool = False,
nr_docs: int = 4,
diversity: float = None,
):
self.model = model
self.prompt = prompt if prompt else DEFAULT_PROMPT
self.default_prompt_ = DEFAULT_PROMPT
self.delay_in_seconds = delay_in_seconds
self.exponential_backoff = exponential_backoff
self.nr_docs = nr_docs
self.diversity = diversity
self.generator_kwargs = generator_kwargs
if self.generator_kwargs.get("model"):
self.model = generator_kwargs.get("model")
if self.generator_kwargs.get("prompt"):
del self.generator_kwargs["prompt"]
def extract_topics(
self, topic_model, documents: pd.DataFrame, c_tf_idf: csr_matrix, topics: Mapping[str, List[Tuple[str, float]]]
) -> Mapping[str, List[Tuple[str, float]]]:
"""Extract topics.
Arguments:
topic_model: A BERTopic model
documents: All input documents
c_tf_idf: The topic c-TF-IDF representation
topics: The candidate topics as calculated with c-TF-IDF
Returns:
updated_topics: Updated topic representations
"""
# Extract the top n representative documents per topic
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity
)
# Generate using a (Large) Language Model
updated_topics = {}
for topic, docs in repr_docs_mappings.items():
prompt = self._create_prompt(docs, topic, topics)
# Delay
if self.delay_in_seconds:
time.sleep(self.delay_in_seconds)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
]
kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs}
if self.exponential_backoff:
response = chat_completions_with_backoff(**kwargs)
else:
response = completion(**kwargs)
label = response["choices"][0]["message"]["content"].strip().replace("topic: ", "")
updated_topics[topic] = [(label, 1)]
return updated_topics
def _create_prompt(self, docs, topic, topics):
keywords = list(zip(*topics[topic]))[0]
# Use the Default Chat Prompt
if self.prompt == DEFAULT_PROMPT:
prompt = self.prompt.replace("[KEYWORDS]", " ".join(keywords))
prompt = self._replace_documents(prompt, docs)
# Use a custom prompt that leverages keywords, documents or both using
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
else:
prompt = self.prompt
if "[KEYWORDS]" in prompt:
prompt = prompt.replace("[KEYWORDS]", " ".join(keywords))
if "[DOCUMENTS]" in prompt:
prompt = self._replace_documents(prompt, docs)
return prompt
@staticmethod
def _replace_documents(prompt, docs):
to_replace = ""
for doc in docs:
to_replace += f"- {doc[:255]}\n"
prompt = prompt.replace("[DOCUMENTS]", to_replace)
return prompt
def chat_completions_with_backoff(**kwargs):
return retry_with_exponential_backoff(
completion,
)(**kwargs)
@@ -0,0 +1,215 @@
import pandas as pd
from tqdm import tqdm
from scipy.sparse import csr_matrix
from llama_cpp import Llama
from typing import Mapping, List, Tuple, Any, Union, Callable
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters
DEFAULT_PROMPT = """
This is a list of texts where each collection of texts describe a topic. After each collection of texts, the name of the topic they represent is mentioned as a short-highly-descriptive title
---
Topic:
Sample texts from this topic:
- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
- Meat, but especially beef, is the word food in terms of emissions.
- Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one.
Keywords: meat beef eat eating emissions steak food health processed chicken
Topic name: Environmental impacts of eating meat
---
Topic:
Sample texts from this topic:
- I have ordered the product weeks ago but it still has not arrived!
- The website mentions that it only takes a couple of days to deliver but I still have not received mine.
- I got a message stating that I received the monitor but that is not true!
- It took a month longer to deliver than was advised...
Keywords: deliver weeks product shipping long delivery received arrived arrive week
Topic name: Shipping and delivery issues
---
Topic:
Sample texts from this topic:
[DOCUMENTS]
Keywords: [KEYWORDS]
Topic name:"""
DEFAULT_SYSTEM_PROMPT = "You are an assistant that extracts high-level topics from texts."
class LlamaCPP(BaseRepresentation):
"""A llama.cpp implementation to use as a representation model.
Arguments:
model: Either a string pointing towards a local LLM or a
`llama_cpp.Llama` object.
prompt: The prompt to be used in the model. If no prompt is given,
`self.default_prompt_` is used instead.
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
to decide where the keywords and documents need to be
inserted.
system_prompt: The system prompt to be used in the model. If no system prompt is given,
`self.default_system_prompt_` is used instead.
pipeline_kwargs: Kwargs that you can pass to the `llama_cpp.Llama`
when it is called such as `max_tokens` to be generated.
nr_docs: The number of documents to pass to OpenAI if a prompt
with the `["DOCUMENTS"]` tag is used.
diversity: The diversity of documents to pass to OpenAI.
Accepts values between 0 and 1. A higher
values results in passing more diverse documents
whereas lower values passes more similar documents.
doc_length: The maximum length of each document. If a document is longer,
it will be truncated. If None, the entire document is passed.
tokenizer: The tokenizer used to calculate to split the document into segments
used to count the length of a document.
* If tokenizer is 'char', then the document is split up
into characters which are counted to adhere to `doc_length`
* If tokenizer is 'whitespace', the the document is split up
into words separated by whitespaces. These words are counted
and truncated depending on `doc_length`
* If tokenizer is 'vectorizer', then the internal CountVectorizer
is used to tokenize the document. These tokens are counted
and truncated depending on `doc_length`
* If tokenizer is a callable, then that callable is used to tokenize
the document. These tokens are counted and truncated depending
on `doc_length`
Usage:
To use a llama.cpp, first download the LLM:
```bash
wget https://huggingface.co/TheBloke/zephyr-7B-alpha-GGUF/resolve/main/zephyr-7b-alpha.Q4_K_M.gguf
```
Then, we can now use the model the model with BERTopic in just a couple of lines:
```python
from bertopic import BERTopic
from bertopic.representation import LlamaCPP
# Use llama.cpp to load in a 4-bit quantized version of Zephyr 7B Alpha
representation_model = LlamaCPP("zephyr-7b-alpha.Q4_K_M.gguf")
# Create our BERTopic model
topic_model = BERTopic(representation_model=representation_model, verbose=True)
```
If you want to have more control over the LLMs parameters, you can run it like so:
```python
from bertopic import BERTopic
from bertopic.representation import LlamaCPP
from llama_cpp import Llama
# Use llama.cpp to load in a 4-bit quantized version of Zephyr 7B Alpha
llm = Llama(model_path="zephyr-7b-alpha.Q4_K_M.gguf", n_gpu_layers=-1, n_ctx=4096, stop="Q:")
representation_model = LlamaCPP(llm)
# Create our BERTopic model
topic_model = BERTopic(representation_model=representation_model, verbose=True)
```
"""
def __init__(
self,
model: Union[str, Llama],
prompt: str = None,
system_prompt: str = None,
pipeline_kwargs: Mapping[str, Any] = {},
nr_docs: int = 4,
diversity: float = None,
doc_length: int = None,
tokenizer: Union[str, Callable] = None,
):
if isinstance(model, str):
self.model = Llama(model_path=model, n_gpu_layers=-1, stop="\n", chat_format="ChatML")
elif isinstance(model, Llama):
self.model = model
else:
raise ValueError(
"Make sure that the model that you"
"pass is either a string referring to a"
"local LLM or a ` llama_cpp.Llama` object."
)
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
self.system_prompt = system_prompt if system_prompt is not None else DEFAULT_SYSTEM_PROMPT
self.default_prompt_ = DEFAULT_PROMPT
self.default_system_prompt_ = DEFAULT_SYSTEM_PROMPT
self.pipeline_kwargs = pipeline_kwargs
self.nr_docs = nr_docs
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
validate_truncate_document_parameters(self.tokenizer, self.doc_length)
self.prompts_ = []
def extract_topics(
self,
topic_model,
documents: pd.DataFrame,
c_tf_idf: csr_matrix,
topics: Mapping[str, List[Tuple[str, float]]],
) -> Mapping[str, List[Tuple[str, float]]]:
"""Extract topic representations and return a single label.
Arguments:
topic_model: A BERTopic model
documents: Not used
c_tf_idf: Not used
topics: The candidate topics as calculated with c-TF-IDF
Returns:
updated_topics: Updated topic representations
"""
# Extract the top 4 representative documents per topic
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity
)
updated_topics = {}
for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
# Prepare prompt
truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
prompt = self._create_prompt(truncated_docs, topic, topics)
self.prompts_.append(prompt)
# Extract result from generator and use that as label
# topic_description = self.model(prompt, **self.pipeline_kwargs)["choices"]
topic_description = self.model.create_chat_completion(
messages=[{"role": "system", "content": self.system_prompt}, {"role": "user", "content": prompt}],
**self.pipeline_kwargs,
)
label = topic_description["choices"][0]["message"]["content"].strip()
updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)]
return updated_topics
def _create_prompt(self, docs, topic, topics):
keywords = list(zip(*topics[topic]))[0]
# Use the Default Chat Prompt
if self.prompt == DEFAULT_PROMPT:
prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
prompt = self._replace_documents(prompt, docs)
# Use a custom prompt that leverages keywords, documents or both using
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
else:
prompt = self.prompt
if "[KEYWORDS]" in prompt:
prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords))
if "[DOCUMENTS]" in prompt:
prompt = self._replace_documents(prompt, docs)
return prompt
@staticmethod
def _replace_documents(prompt, docs):
to_replace = ""
for doc in docs:
to_replace += f"- {doc}\n"
prompt = prompt.replace("[DOCUMENTS]", to_replace)
return prompt
@@ -0,0 +1,128 @@
import warnings
import numpy as np
import pandas as pd
from typing import List, Mapping, Tuple
from scipy.sparse import csr_matrix
from sklearn.metrics.pairwise import cosine_similarity
from bertopic.representation._base import BaseRepresentation
class MaximalMarginalRelevance(BaseRepresentation):
"""Calculate Maximal Marginal Relevance (MMR)
between candidate keywords and the document.
MMR considers the similarity of keywords/keyphrases with the
document, along with the similarity of already selected
keywords and keyphrases. This results in a selection of keywords
that maximize their within diversity with respect to the document.
Arguments:
diversity: How diverse the select keywords/keyphrases are.
Values range between 0 and 1 with 0 being not diverse at all
and 1 being most diverse.
top_n_words: The number of keywords/keyhprases to return
Usage:
```python
from bertopic.representation import MaximalMarginalRelevance
from bertopic import BERTopic
# Create your representation model
representation_model = MaximalMarginalRelevance(diversity=0.3)
# Use the representation model in BERTopic on top of the default pipeline
topic_model = BERTopic(representation_model=representation_model)
```
"""
def __init__(self, diversity: float = 0.1, top_n_words: int = 10):
self.diversity = diversity
self.top_n_words = top_n_words
def extract_topics(
self,
topic_model,
documents: pd.DataFrame,
c_tf_idf: csr_matrix,
topics: Mapping[str, List[Tuple[str, float]]],
) -> Mapping[str, List[Tuple[str, float]]]:
"""Extract topic representations.
Arguments:
topic_model: The BERTopic model
documents: Not used
c_tf_idf: Not used
topics: The candidate topics as calculated with c-TF-IDF
Returns:
updated_topics: Updated topic representations
"""
if topic_model.embedding_model is None:
warnings.warn(
"MaximalMarginalRelevance can only be used BERTopic was instantiated"
"with the `embedding_model` parameter."
)
return topics
updated_topics = {}
for topic, topic_words in topics.items():
words = [word[0] for word in topic_words]
word_embeddings = topic_model._extract_embeddings(words, method="word", verbose=False)
topic_embedding = topic_model._extract_embeddings(" ".join(words), method="word", verbose=False).reshape(
1, -1
)
topic_words = mmr(
topic_embedding,
word_embeddings,
words,
self.diversity,
self.top_n_words,
)
updated_topics[topic] = [(word, value) for word, value in topics[topic] if word in topic_words]
return updated_topics
def mmr(
doc_embedding: np.ndarray,
word_embeddings: np.ndarray,
words: List[str],
diversity: float = 0.1,
top_n: int = 10,
) -> List[str]:
"""Maximal Marginal Relevance.
Arguments:
doc_embedding: The document embeddings
word_embeddings: The embeddings of the selected candidate keywords/phrases
words: The selected candidate keywords/keyphrases
diversity: The diversity of the selected embeddings.
Values between 0 and 1.
top_n: The top n items to return
Returns:
List[str]: The selected keywords/keyphrases
"""
# Extract similarity within words, and between words and the document
word_doc_similarity = cosine_similarity(word_embeddings, doc_embedding)
word_similarity = cosine_similarity(word_embeddings)
# Initialize candidates and already choose best keyword/keyphras
keywords_idx = [np.argmax(word_doc_similarity)]
candidates_idx = [i for i in range(len(words)) if i != keywords_idx[0]]
for _ in range(top_n - 1):
# Extract similarities within candidates and
# between candidates and selected keywords/phrases
candidate_similarities = word_doc_similarity[candidates_idx, :]
target_similarities = np.max(word_similarity[candidates_idx][:, keywords_idx], axis=1)
# Calculate MMR
mmr = (1 - diversity) * candidate_similarities - diversity * target_similarities.reshape(-1, 1)
mmr_idx = candidates_idx[np.argmax(mmr)]
# Update keywords & candidates
keywords_idx.append(mmr_idx)
candidates_idx.remove(mmr_idx)
return [words[idx] for idx in keywords_idx]
@@ -0,0 +1,274 @@
import time
import openai
import pandas as pd
from tqdm import tqdm
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple, Any, Union, Callable
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import (
retry_with_exponential_backoff,
truncate_document,
validate_truncate_document_parameters,
)
DEFAULT_CHAT_PROMPT = """You will extract a short topic label from given documents and keywords.
Here are two examples of topics you created before:
# Example 1
Sample texts from this topic:
- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
- Meat, but especially beef, is the worst food in terms of emissions.
- Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one.
Keywords: meat beef eat eating emissions steak food health processed chicken
topic: Environmental impacts of eating meat
# Example 2
Sample texts from this topic:
- I have ordered the product weeks ago but it still has not arrived!
- The website mentions that it only takes a couple of days to deliver but I still have not received mine.
- I got a message stating that I received the monitor but that is not true!
- It took a month longer to deliver than was advised...
Keywords: deliver weeks product shipping long delivery received arrived arrive week
topic: Shipping and delivery issues
# Your task
Sample texts from this topic:
[DOCUMENTS]
Keywords: [KEYWORDS]
Based on the information above, extract a short topic label (three words at most) in the following format:
topic: <topic_label>
"""
DEFAULT_SYSTEM_PROMPT = "You are an assistant that extracts high-level topics from texts."
class OpenAI(BaseRepresentation):
r"""Using the OpenAI API to generate topic labels based
on one of their Completion of ChatCompletion models.
For an overview see:
https://platform.openai.com/docs/models
Arguments:
client: A `openai.OpenAI` client
model: Model to use within OpenAI, defaults to `"gpt-4o-mini"`.
generator_kwargs: Kwargs passed to `openai.Completion.create`
for fine-tuning the output.
prompt: The prompt to be used in the model. If no prompt is given,
`self.default_prompt_` is used instead.
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
to decide where the keywords and documents need to be
inserted.
system_prompt: The system prompt to be used in the model. If no system prompt is given,
`self.default_system_prompt_` is used instead.
delay_in_seconds: The delay in seconds between consecutive prompts
in order to prevent RateLimitErrors.
exponential_backoff: Retry requests with a random exponential backoff.
A short sleep is used when a rate limit error is hit,
then the requests is retried. Increase the sleep length
if errors are hit until 10 unsuccessful requests.
If True, overrides `delay_in_seconds`.
nr_docs: The number of documents to pass to OpenAI if a prompt
with the `["DOCUMENTS"]` tag is used.
diversity: The diversity of documents to pass to OpenAI.
Accepts values between 0 and 1. A higher
values results in passing more diverse documents
whereas lower values passes more similar documents.
doc_length: The maximum length of each document. If a document is longer,
it will be truncated. If None, the entire document is passed.
tokenizer: The tokenizer used to calculate to split the document into segments
used to count the length of a document.
* If tokenizer is 'char', then the document is split up
into characters which are counted to adhere to `doc_length`
* If tokenizer is 'whitespace', the document is split up
into words separated by whitespaces. These words are counted
and truncated depending on `doc_length`
* If tokenizer is 'vectorizer', then the internal CountVectorizer
is used to tokenize the document. These tokens are counted
and truncated depending on `doc_length`
* If tokenizer is a callable, then that callable is used to tokenize
the document. These tokens are counted and truncated depending
on `doc_length`
Usage:
To use this, you will need to install the openai package first:
`pip install openai`
Then, get yourself an API key and use OpenAI's API as follows:
```python
import openai
from bertopic.representation import OpenAI
from bertopic import BERTopic
# Create your representation model
client = openai.OpenAI(api_key=MY_API_KEY)
representation_model = OpenAI(client, delay_in_seconds=5)
# Use the representation model in BERTopic on top of the default pipeline
topic_model = BERTopic(representation_model=representation_model)
```
You can also use a custom prompt:
```python
prompt = "I have the following documents: [DOCUMENTS] \nThese documents are about the following topic: '"
representation_model = OpenAI(client, prompt=prompt, delay_in_seconds=5)
```
To choose a model:
```python
representation_model = OpenAI(client, model="gpt-4o-mini", delay_in_seconds=10)
```
"""
def __init__(
self,
client,
model: str = "gpt-4o-mini",
prompt: str = None,
system_prompt: str = None,
generator_kwargs: Mapping[str, Any] = {},
delay_in_seconds: float = None,
exponential_backoff: bool = False,
nr_docs: int = 4,
diversity: float = None,
doc_length: int = None,
tokenizer: Union[str, Callable] = None,
**kwargs,
):
self.client = client
self.model = model
if prompt is None:
self.prompt = DEFAULT_CHAT_PROMPT
else:
self.prompt = prompt
if system_prompt is None:
self.system_prompt = DEFAULT_SYSTEM_PROMPT
else:
self.system_prompt = system_prompt
self.default_prompt_ = DEFAULT_CHAT_PROMPT
self.default_system_prompt_ = DEFAULT_SYSTEM_PROMPT
self.delay_in_seconds = delay_in_seconds
self.exponential_backoff = exponential_backoff
self.nr_docs = nr_docs
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
validate_truncate_document_parameters(self.tokenizer, self.doc_length)
self.prompts_ = []
self.generator_kwargs = generator_kwargs
if self.generator_kwargs.get("model"):
self.model = generator_kwargs.get("model")
del self.generator_kwargs["model"]
if self.generator_kwargs.get("prompt"):
del self.generator_kwargs["prompt"]
if not self.generator_kwargs.get("stop"):
self.generator_kwargs["stop"] = "\n"
def extract_topics(
self,
topic_model,
documents: pd.DataFrame,
c_tf_idf: csr_matrix,
topics: Mapping[str, List[Tuple[str, float]]],
) -> Mapping[str, List[Tuple[str, float]]]:
"""Extract topics.
Arguments:
topic_model: A BERTopic model
documents: All input documents
c_tf_idf: The topic c-TF-IDF representation
topics: The candidate topics as calculated with c-TF-IDF
Returns:
updated_topics: Updated topic representations
"""
# Extract the top n representative documents per topic
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity
)
# Generate using OpenAI's Language Model
updated_topics = {}
for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
prompt = self._create_prompt(truncated_docs, topic, topics)
self.prompts_.append(prompt)
# Delay
if self.delay_in_seconds:
time.sleep(self.delay_in_seconds)
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": prompt},
]
kwargs = {
"model": self.model,
"messages": messages,
**self.generator_kwargs,
}
if self.exponential_backoff:
response = chat_completions_with_backoff(self.client, **kwargs)
else:
response = self.client.chat.completions.create(**kwargs)
# Check whether content was actually generated
# Addresses #1570 for potential issues with OpenAI's content filter
# Addresses #2176 for potential issues when openAI returns a None type object
if response and hasattr(response.choices[0].message, "content"):
label = response.choices[0].message.content.strip().replace("topic: ", "")
else:
label = "No label returned"
updated_topics[topic] = [(label, 1)]
return updated_topics
def _create_prompt(self, docs, topic, topics):
keywords = list(zip(*topics[topic]))[0]
# Use the Default Chat Prompt
if self.prompt == DEFAULT_CHAT_PROMPT:
prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
prompt = self._replace_documents(prompt, docs)
# Use a custom prompt that leverages keywords, documents or both using
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
else:
prompt = self.prompt
if "[KEYWORDS]" in prompt:
prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords))
if "[DOCUMENTS]" in prompt:
prompt = self._replace_documents(prompt, docs)
return prompt
@staticmethod
def _replace_documents(prompt, docs):
to_replace = ""
for doc in docs:
to_replace += f"- {doc}\n"
prompt = prompt.replace("[DOCUMENTS]", to_replace)
return prompt
def chat_completions_with_backoff(client, **kwargs):
return retry_with_exponential_backoff(
client.chat.completions.create,
errors=(openai.RateLimitError,),
)(**kwargs)
@@ -0,0 +1,161 @@
import numpy as np
import pandas as pd
import spacy
from spacy.matcher import Matcher
from spacy.language import Language
from packaging import version
from scipy.sparse import csr_matrix
from typing import List, Mapping, Tuple, Union
from sklearn import __version__ as sklearn_version
from bertopic.representation._base import BaseRepresentation
class PartOfSpeech(BaseRepresentation):
"""Extract Topic Keywords based on their Part-of-Speech.
DEFAULT_PATTERNS = [
[{'POS': 'ADJ'}, {'POS': 'NOUN'}],
[{'POS': 'NOUN'}],
[{'POS': 'ADJ'}]
]
From candidate topics, as extracted with c-TF-IDF,
find documents that contain keywords found in the
candidate topics. These candidate documents then
serve as the representative set of documents from
which the Spacy model can extract a set of candidate
keywords for each topic.
These candidate keywords are first judged by whether
they fall within the DEFAULT_PATTERNS or the user-defined
pattern. Then, the resulting keywords are sorted by
their respective c-TF-IDF values.
Arguments:
model: The Spacy model to use
top_n_words: The top n words to extract
pos_patterns: Patterns for Spacy to use.
See https://spacy.io/usage/rule-based-matching
Usage:
```python
from bertopic.representation import PartOfSpeech
from bertopic import BERTopic
# Create your representation model
representation_model = PartOfSpeech("en_core_web_sm")
# Use the representation model in BERTopic on top of the default pipeline
topic_model = BERTopic(representation_model=representation_model)
```
You can define custom POS patterns to be extracted:
```python
pos_patterns = [
[{'POS': 'ADJ'}, {'POS': 'NOUN'}],
[{'POS': 'NOUN'}], [{'POS': 'ADJ'}]
]
representation_model = PartOfSpeech("en_core_web_sm", pos_patterns=pos_patterns)
```
"""
def __init__(
self,
model: Union[str, Language] = "en_core_web_sm",
top_n_words: int = 10,
pos_patterns: List[str] = None,
):
if isinstance(model, str):
self.model = spacy.load(model)
elif isinstance(model, Language):
self.model = model
else:
raise ValueError(
"Make sure that the Spacy model that you"
"pass is either a string referring to a"
"Spacy model or a Spacy nlp object."
)
self.top_n_words = top_n_words
if pos_patterns is None:
self.pos_patterns = [
[{"POS": "ADJ"}, {"POS": "NOUN"}],
[{"POS": "NOUN"}],
[{"POS": "ADJ"}],
]
else:
self.pos_patterns = pos_patterns
def extract_topics(
self,
topic_model,
documents: pd.DataFrame,
c_tf_idf: csr_matrix,
topics: Mapping[str, List[Tuple[str, float]]],
) -> Mapping[str, List[Tuple[str, float]]]:
"""Extract topics.
Arguments:
topic_model: A BERTopic model
documents: All input documents
c_tf_idf: Not used
topics: The candidate topics as calculated with c-TF-IDF
Returns:
updated_topics: Updated topic representations
"""
matcher = Matcher(self.model.vocab)
matcher.add("Pattern", self.pos_patterns)
candidate_topics = {}
for topic, values in topics.items():
keywords = list(zip(*values))[0]
# Extract candidate documents
candidate_documents = []
for keyword in keywords:
selection = documents.loc[documents.Topic == topic, :]
selection = selection.loc[selection.Document.str.contains(keyword, regex=False), "Document"]
if len(selection) > 0:
for document in selection[:2]:
candidate_documents.append(document)
candidate_documents = list(set(candidate_documents))
# Extract keywords
docs_pipeline = self.model.pipe(candidate_documents)
updated_keywords = []
for doc in docs_pipeline:
matches = matcher(doc)
for _, start, end in matches:
updated_keywords.append(doc[start:end].text)
candidate_topics[topic] = list(set(updated_keywords))
# Scikit-Learn Deprecation: get_feature_names is deprecated in 1.0
# and will be removed in 1.2. Please use get_feature_names_out instead.
if version.parse(sklearn_version) >= version.parse("1.0.0"):
words = list(topic_model.vectorizer_model.get_feature_names_out())
else:
words = list(topic_model.vectorizer_model.get_feature_names())
# Match updated keywords with c-TF-IDF values
words_lookup = dict(zip(words, range(len(words))))
updated_topics = {topic: [] for topic in topics.keys()}
for topic, candidate_keywords in candidate_topics.items():
word_indices = np.sort(
[words_lookup.get(keyword) for keyword in candidate_keywords if keyword in words_lookup]
)
vals = topic_model.c_tf_idf_[:, word_indices][topic + topic_model._outliers]
indices = np.argsort(np.array(vals.todense().reshape(1, -1))[0])[-self.top_n_words :][::-1]
vals = np.sort(np.array(vals.todense().reshape(1, -1))[0])[-self.top_n_words :][::-1]
topic_words = [(words[word_indices[index]], val) for index, val in zip(indices, vals)]
updated_topics[topic] = topic_words
if len(updated_topics[topic]) < self.top_n_words:
updated_topics[topic] += [("", 0) for _ in range(self.top_n_words - len(updated_topics[topic]))]
return updated_topics
@@ -0,0 +1,188 @@
import pandas as pd
from tqdm import tqdm
from scipy.sparse import csr_matrix
from transformers import pipeline, set_seed
from transformers.pipelines.base import Pipeline
from typing import Mapping, List, Tuple, Any, Union, Callable
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters
DEFAULT_PROMPT = """
I have a topic described by the following keywords: [KEYWORDS].
The name of this topic is:
"""
class TextGeneration(BaseRepresentation):
"""Text2Text or text generation with transformers.
Arguments:
model: A transformers pipeline that should be initialized as "text-generation"
for gpt-like models or "text2text-generation" for T5-like models.
For example, `pipeline('text-generation', model='gpt2')`. If a string
is passed, "text-generation" will be selected by default.
prompt: The prompt to be used in the model. If no prompt is given,
`self.default_prompt_` is used instead.
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
to decide where the keywords and documents need to be
inserted.
pipeline_kwargs: Kwargs that you can pass to the transformers.pipeline
when it is called.
random_state: A random state to be passed to `transformers.set_seed`
nr_docs: The number of documents to pass to OpenAI if a prompt
with the `["DOCUMENTS"]` tag is used.
diversity: The diversity of documents to pass to OpenAI.
Accepts values between 0 and 1. A higher
values results in passing more diverse documents
whereas lower values passes more similar documents.
doc_length: The maximum length of each document. If a document is longer,
it will be truncated. If None, the entire document is passed.
tokenizer: The tokenizer used to calculate to split the document into segments
used to count the length of a document.
* If tokenizer is 'char', then the document is split up
into characters which are counted to adhere to `doc_length`
* If tokenizer is 'whitespace', the document is split up
into words separated by whitespaces. These words are counted
and truncated depending on `doc_length`
* If tokenizer is 'vectorizer', then the internal CountVectorizer
is used to tokenize the document. These tokens are counted
and truncated depending on `doc_length`
* If tokenizer is a callable, then that callable is used to tokenize
the document. These tokens are counted and truncated depending
on `doc_length`
Usage:
To use a gpt-like model:
```python
from bertopic.representation import TextGeneration
from bertopic import BERTopic
# Create your representation model
generator = pipeline('text-generation', model='gpt2')
representation_model = TextGeneration(generator)
# Use the representation model in BERTopic on top of the default pipeline
topic_model = BERTo pic(representation_model=representation_model)
```
You can use a custom prompt and decide where the keywords should
be inserted by using the `[KEYWORDS]` or documents with thte `[DOCUMENTS]` tag:
```python
from bertopic.representation import TextGeneration
prompt = "I have a topic described by the following keywords: [KEYWORDS]. Based on the previous keywords, what is this topic about?""
# Create your representation model
generator = pipeline('text2text-generation', model='google/flan-t5-base')
representation_model = TextGeneration(generator)
```
"""
def __init__(
self,
model: Union[str, pipeline],
prompt: str = None,
pipeline_kwargs: Mapping[str, Any] = {},
random_state: int = 42,
nr_docs: int = 4,
diversity: float = None,
doc_length: int = None,
tokenizer: Union[str, Callable] = None,
):
self.random_state = random_state
set_seed(random_state)
if isinstance(model, str):
self.model = pipeline("text-generation", model=model)
elif isinstance(model, Pipeline):
self.model = model
else:
raise ValueError(
"Make sure that the HF model that you"
"pass is either a string referring to a"
"HF model or a `transformers.pipeline` object."
)
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
self.default_prompt_ = DEFAULT_PROMPT
self.pipeline_kwargs = pipeline_kwargs
self.nr_docs = nr_docs
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
validate_truncate_document_parameters(self.tokenizer, self.doc_length)
self.prompts_ = []
def extract_topics(
self,
topic_model,
documents: pd.DataFrame,
c_tf_idf: csr_matrix,
topics: Mapping[str, List[Tuple[str, float]]],
) -> Mapping[str, List[Tuple[str, float]]]:
"""Extract topic representations and return a single label.
Arguments:
topic_model: A BERTopic model
documents: Not used
c_tf_idf: Not used
topics: The candidate topics as calculated with c-TF-IDF
Returns:
updated_topics: Updated topic representations
"""
# Extract the top 4 representative documents per topic
if self.prompt != DEFAULT_PROMPT and "[DOCUMENTS]" in self.prompt:
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity
)
else:
repr_docs_mappings = {topic: None for topic in topics.keys()}
updated_topics = {}
for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
# Prepare prompt
truncated_docs = (
[truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
if docs is not None
else docs
)
prompt = self._create_prompt(truncated_docs, topic, topics)
self.prompts_.append(prompt)
# Extract result from generator and use that as label
topic_description = self.model(prompt, **self.pipeline_kwargs)
topic_description = [
(description["generated_text"].replace(prompt, ""), 1) for description in topic_description
]
if len(topic_description) < 10:
topic_description += [("", 0) for _ in range(10 - len(topic_description))]
updated_topics[topic] = topic_description
return updated_topics
def _create_prompt(self, docs, topic, topics):
keywords = ", ".join(list(zip(*topics[topic]))[0])
# Use the default prompt and replace keywords
if self.prompt == DEFAULT_PROMPT:
prompt = self.prompt.replace("[KEYWORDS]", keywords)
# Use a prompt that leverages either keywords or documents in
# a custom location
else:
prompt = self.prompt
if "[KEYWORDS]" in prompt:
prompt = prompt.replace("[KEYWORDS]", keywords)
if "[DOCUMENTS]" in prompt:
to_replace = ""
for doc in docs:
to_replace += f"- {doc}\n"
prompt = prompt.replace("[DOCUMENTS]", to_replace)
return prompt
@@ -0,0 +1,113 @@
import random
import time
from typing import Union
def truncate_document(topic_model, doc_length: Union[int, None], tokenizer: Union[str, callable], document: str) -> str:
"""Truncate a document to a certain length.
If you want to add a custom tokenizer, then it will need to have a `decode` and
`encode` method. An example would be the following custom tokenizer:
```python
class Tokenizer:
'A custom tokenizer that splits on commas'
def encode(self, doc):
return doc.split(",")
def decode(self, doc_chunks):
return ",".join(doc_chunks)
```
You can use this tokenizer by passing it to the `tokenizer` parameter.
Arguments:
topic_model: A BERTopic model
doc_length: The maximum length of each document. If a document is longer,
it will be truncated. If None, the entire document is passed.
tokenizer: The tokenizer used to calculate to split the document into segments
used to count the length of a document.
* If tokenizer is 'char', then the document is split up
into characters which are counted to adhere to `doc_length`
* If tokenizer is 'whitespace', the document is split up
into words separated by whitespaces. These words are counted
and truncated depending on `doc_length`
* If tokenizer is 'vectorizer', then the internal CountVectorizer
is used to tokenize the document. These tokens are counted
and truncated depending on `doc_length`. They are decoded with
whitespaces.
* If tokenizer is a callable, then that callable is used to tokenize
the document. These tokens are counted and truncated depending
on `doc_length`
document: A single document
Returns:
truncated_document: A truncated document
"""
if doc_length is not None:
if tokenizer == "char":
truncated_document = document[:doc_length]
elif tokenizer == "whitespace":
truncated_document = " ".join(document.split()[:doc_length])
elif tokenizer == "vectorizer":
tokenizer = topic_model.vectorizer_model.build_tokenizer()
truncated_document = " ".join(tokenizer(document)[:doc_length])
elif hasattr(tokenizer, "encode") and hasattr(tokenizer, "decode"):
encoded_document = tokenizer.encode(document)
truncated_document = tokenizer.decode(encoded_document[:doc_length])
return truncated_document
return document
def validate_truncate_document_parameters(tokenizer, doc_length) -> Union[None, ValueError]:
"""Validates parameters that are used in the function `truncate_document`."""
if tokenizer is None and doc_length is not None:
raise ValueError(
"Please select from one of the valid options for the `tokenizer` parameter: \n"
"{'char', 'whitespace', 'vectorizer'} \n"
"If `tokenizer` is of type callable ensure it has methods to encode and decode a document \n"
)
elif tokenizer is not None and doc_length is None:
raise ValueError("If `tokenizer` is provided, `doc_length` of type int must be provided as well.")
def retry_with_exponential_backoff(
func,
initial_delay: float = 1,
exponential_base: float = 2,
jitter: bool = True,
max_retries: int = 10,
errors: tuple = None,
):
"""Retry a function with exponential backoff."""
def wrapper(*args, **kwargs):
# Initialize variables
num_retries = 0
delay = initial_delay
# Loop until a successful response or max_retries is hit or an exception is raised
while True:
try:
return func(*args, **kwargs)
# Retry on specific errors
except errors:
# Increment retries
num_retries += 1
# Check if max retries has been reached
if num_retries > max_retries:
raise Exception(f"Maximum number of retries ({max_retries}) exceeded.")
# Increment the delay
delay *= exponential_base * (1 + jitter * random.random())
# Sleep for the delay
time.sleep(delay)
# Raise exceptions for any errors not specified
except Exception as e:
raise e
return wrapper
@@ -0,0 +1,274 @@
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple, Union
from transformers.pipelines import Pipeline, pipeline
from bertopic.representation._mmr import mmr
from bertopic.representation._base import BaseRepresentation
class VisualRepresentation(BaseRepresentation):
"""From a collection of representative documents, extract
images to represent topics. These topics are represented by a
collage of images.
Arguments:
nr_repr_images: Number of representative images to extract
nr_samples: The number of candidate documents to extract per cluster.
image_height: The height of the resulting collage
image_square: Whether to resize each image in the collage
to a square. This can be visually more appealing
if all input images are all almost squares.
image_to_text_model: The model to caption images.
batch_size: The number of images to pass to the
`image_to_text_model`.
Usage:
```python
from bertopic.representation import VisualRepresentation
from bertopic import BERTopic
# The visual representation is typically not a core representation
# and is advised to pass to BERTopic as an additional aspect.
# Aspects can be labeled with dictionaries as shown below:
representation_model = {
"Visual_Aspect": VisualRepresentation()
}
# Use the representation model in BERTopic as a separate aspect
topic_model = BERTopic(representation_model=representation_model)
```
"""
def __init__(
self,
nr_repr_images: int = 9,
nr_samples: int = 500,
image_height: Tuple[int, int] = 600,
image_squares: bool = False,
image_to_text_model: Union[str, Pipeline] = None,
batch_size: int = 32,
):
self.nr_repr_images = nr_repr_images
self.nr_samples = nr_samples
self.image_height = image_height
self.image_squares = image_squares
# Text-to-image model
if isinstance(image_to_text_model, Pipeline):
self.image_to_text_model = image_to_text_model
elif isinstance(image_to_text_model, str):
self.image_to_text_model = pipeline("image-to-text", model=image_to_text_model)
elif image_to_text_model is None:
self.image_to_text_model = None
else:
raise ValueError(
"Please select a correct transformers pipeline. For example:"
"pipeline('image-to-text', model='nlpconnect/vit-gpt2-image-captioning')"
)
self.batch_size = batch_size
def extract_topics(
self,
topic_model,
documents: pd.DataFrame,
c_tf_idf: csr_matrix,
topics: Mapping[str, List[Tuple[str, float]]],
) -> Mapping[str, List[Tuple[str, float]]]:
"""Extract topics.
Arguments:
topic_model: A BERTopic model
documents: All input documents
c_tf_idf: The topic c-TF-IDF representation
topics: The candidate topics as calculated with c-TF-IDF
Returns:
representative_images: Representative images per topic
"""
# Extract image ids of most representative documents
images = documents["Image"].values.tolist()
(_, _, _, repr_docs_ids) = topic_model._extract_representative_docs(
c_tf_idf,
documents,
topics,
nr_samples=self.nr_samples,
nr_repr_docs=self.nr_repr_images,
)
unique_topics = sorted(list(topics.keys()))
# Combine representative images into a single representation
representative_images = {}
for topic in tqdm(unique_topics):
# Get and order represetnative images
sliced_examplars = repr_docs_ids[topic + topic_model._outliers]
sliced_examplars = [sliced_examplars[i : i + 3] for i in range(0, len(sliced_examplars), 3)]
images_to_combine = [
[
Image.open(images[index]) if isinstance(images[index], str) else images[index]
for index in sub_indices
]
for sub_indices in sliced_examplars
]
# Concatenate representative images
representative_image = get_concat_tile_resize(images_to_combine, self.image_height, self.image_squares)
representative_images[topic] = representative_image
# Make sure to properly close images
if isinstance(images[0], str):
for image_list in images_to_combine:
for image in image_list:
image.close()
return representative_images
def _convert_image_to_text(self, images: List[str], verbose: bool = False) -> List[str]:
"""Convert a list of images to captions.
Arguments:
images: A list of images or words to be converted to text.
verbose: Controls the verbosity of the process
Returns:
List of captions
"""
# Batch-wise image conversion
if self.batch_size is not None:
documents = []
for batch in tqdm(self._chunks(images), disable=not verbose):
outputs = self.image_to_text_model(batch)
captions = [output[0]["generated_text"] for output in outputs]
documents.extend(captions)
# Convert images to text
else:
outputs = self.image_to_text_model(images)
documents = [output[0]["generated_text"] for output in outputs]
return documents
def image_to_text(self, documents: pd.DataFrame, embeddings: np.ndarray) -> pd.DataFrame:
"""Convert images to text."""
# Create image topic embeddings
topics = documents.Topic.values.tolist()
images = documents.Image.values.tolist()
df = pd.DataFrame(np.hstack([np.array(topics).reshape(-1, 1), embeddings]))
image_topic_embeddings = df.groupby(0).mean().values
# Extract image centroids
image_centroids = {}
unique_topics = sorted(list(set(topics)))
for topic, topic_embedding in zip(unique_topics, image_topic_embeddings):
indices = np.array([index for index, t in enumerate(topics) if t == topic])
top_n = min([self.nr_repr_images, len(indices)])
indices = mmr(
topic_embedding.reshape(1, -1),
embeddings[indices],
indices,
top_n=top_n,
diversity=0.1,
)
image_centroids[topic] = indices
# Extract documents
documents = pd.DataFrame(columns=["Document", "ID", "Topic", "Image"])
current_id = 0
for topic, image_ids in tqdm(image_centroids.items()):
selected_images = [
Image.open(images[index]) if isinstance(images[index], str) else images[index] for index in image_ids
]
text = self._convert_image_to_text(selected_images)
for doc, image_id in zip(text, image_ids):
documents.loc[len(documents), :] = [
doc,
current_id,
topic,
images[image_id],
]
current_id += 1
# Properly close images
if isinstance(images[image_ids[0]], str):
for image in selected_images:
image.close()
return documents
def _chunks(self, images):
for i in range(0, len(images), self.batch_size):
yield images[i : i + self.batch_size]
def get_concat_h_multi_resize(im_list):
"""Code adapted from: https://note.nkmk.me/en/python-pillow-concat-images/."""
min_height = min(im.height for im in im_list)
min_height = max(im.height for im in im_list)
im_list_resize = []
for im in im_list:
im.resize((int(im.width * min_height / im.height), min_height), resample=0)
im_list_resize.append(im)
total_width = sum(im.width for im in im_list_resize)
dst = Image.new("RGB", (total_width, min_height), (255, 255, 255))
pos_x = 0
for im in im_list_resize:
dst.paste(im, (pos_x, 0))
pos_x += im.width
return dst
def get_concat_v_multi_resize(im_list):
"""Code adapted from: https://note.nkmk.me/en/python-pillow-concat-images/."""
min_width = min(im.width for im in im_list)
min_width = max(im.width for im in im_list)
im_list_resize = [im.resize((min_width, int(im.height * min_width / im.width)), resample=0) for im in im_list]
total_height = sum(im.height for im in im_list_resize)
dst = Image.new("RGB", (min_width, total_height), (255, 255, 255))
pos_y = 0
for im in im_list_resize:
dst.paste(im, (0, pos_y))
pos_y += im.height
return dst
def get_concat_tile_resize(im_list_2d, image_height=600, image_squares=False):
"""Code adapted from: https://note.nkmk.me/en/python-pillow-concat-images/."""
images = [[image.copy() for image in images] for images in im_list_2d]
# Create
if image_squares:
width = int(image_height / 3)
height = int(image_height / 3)
images = [[image.resize((width, height)) for image in images] for images in im_list_2d]
# Resize images based on minimum size
else:
min_width = min([min([img.width for img in imgs]) for imgs in im_list_2d])
min_height = min([min([img.height for img in imgs]) for imgs in im_list_2d])
for i, imgs in enumerate(images):
for j, img in enumerate(imgs):
if img.height > img.width:
images[i][j] = img.resize(
(int(img.width * min_height / img.height), min_height),
resample=0,
)
elif img.width > img.height:
images[i][j] = img.resize((min_width, int(img.height * min_width / img.width)), resample=0)
else:
images[i][j] = img.resize((min_width, min_width))
# Resize grid image
images = [get_concat_h_multi_resize(im_list_h) for im_list_h in images]
img = get_concat_v_multi_resize(images)
height_percentage = image_height / float(img.size[1])
adjusted_width = int((float(img.size[0]) * float(height_percentage)))
img = img.resize((adjusted_width, image_height), Image.Resampling.LANCZOS)
return img
@@ -0,0 +1,104 @@
import pandas as pd
from transformers import pipeline
from transformers.pipelines.base import Pipeline
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple, Any
from bertopic.representation._base import BaseRepresentation
class ZeroShotClassification(BaseRepresentation):
"""Zero-shot Classification on topic keywords with candidate labels.
Arguments:
candidate_topics: A list of labels to assign to the topics if they
exceed `min_prob`
model: A transformers pipeline that should be initialized as
"zero-shot-classification". For example,
`pipeline("zero-shot-classification", model="facebook/bart-large-mnli")`
pipeline_kwargs: Kwargs that you can pass to the transformers.pipeline
when it is called. NOTE: Use `{"multi_label": True}`
to extract multiple labels for each topic.
min_prob: The minimum probability to assign a candidate label to a topic
Usage:
```python
from bertopic.representation import ZeroShotClassification
from bertopic import BERTopic
# Create your representation model
candidate_topics = ["space and nasa", "bicycles", "sports"]
representation_model = ZeroShotClassification(candidate_topics, model="facebook/bart-large-mnli")
# Use the representation model in BERTopic on top of the default pipeline
topic_model = BERTopic(representation_model=representation_model)
```
"""
def __init__(
self,
candidate_topics: List[str],
model: str = "facebook/bart-large-mnli",
pipeline_kwargs: Mapping[str, Any] = {},
min_prob: float = 0.8,
):
self.candidate_topics = candidate_topics
if isinstance(model, str):
self.model = pipeline("zero-shot-classification", model=model)
elif isinstance(model, Pipeline):
self.model = model
else:
raise ValueError(
"Make sure that the HF model that you"
"pass is either a string referring to a"
"HF model or a `transformers.pipeline` object."
)
self.pipeline_kwargs = pipeline_kwargs
self.min_prob = min_prob
def extract_topics(
self,
topic_model,
documents: pd.DataFrame,
c_tf_idf: csr_matrix,
topics: Mapping[str, List[Tuple[str, float]]],
) -> Mapping[str, List[Tuple[str, float]]]:
"""Extract topics.
Arguments:
topic_model: Not used
documents: Not used
c_tf_idf: Not used
topics: The candidate topics as calculated with c-TF-IDF
Returns:
updated_topics: Updated topic representations
"""
# Classify topics
topic_descriptions = [" ".join(list(zip(*topics[topic]))[0]) for topic in topics.keys()]
classifications = self.model(topic_descriptions, self.candidate_topics, **self.pipeline_kwargs)
# Extract labels
updated_topics = {}
for topic, classification in zip(topics.keys(), classifications):
topic_description = topics[topic]
# Multi-label assignment
if self.pipeline_kwargs.get("multi_label"):
topic_description = []
for label, score in zip(classification["labels"], classification["scores"]):
if score > self.min_prob:
topic_description.append((label, score))
# Single label assignment
elif classification["scores"][0] > self.min_prob:
topic_description = [(classification["labels"][0], classification["scores"][0])]
# Make sure that 10 items are returned
if len(topic_description) == 0:
topic_description = topics[topic]
elif len(topic_description) < 10:
topic_description += [("", 0) for _ in range(10 - len(topic_description))]
updated_topics[topic] = topic_description
return updated_topics