Add BERTopic.
This commit is contained in:
@@ -0,0 +1,76 @@
|
||||
from bertopic._utils import NotInstalled
|
||||
from bertopic.representation._cohere import Cohere
|
||||
from bertopic.representation._base import BaseRepresentation
|
||||
from bertopic.representation._keybert import KeyBERTInspired
|
||||
from bertopic.representation._mmr import MaximalMarginalRelevance
|
||||
|
||||
|
||||
# Llama CPP Generator
|
||||
try:
|
||||
from bertopic.representation._llamacpp import LlamaCPP
|
||||
except ModuleNotFoundError:
|
||||
msg = "`pip install llama-cpp-python` \n\n"
|
||||
LlamaCPP = NotInstalled("llama.cpp", "llama-cpp-python", custom_msg=msg)
|
||||
|
||||
# Text Generation using transformers
|
||||
try:
|
||||
from bertopic.representation._textgeneration import TextGeneration
|
||||
except ModuleNotFoundError:
|
||||
msg = "`pip install bertopic` without `--no-deps` \n\n"
|
||||
TextGeneration = NotInstalled("TextGeneration", "transformers", custom_msg=msg)
|
||||
|
||||
# Zero-shot classification using transformers
|
||||
try:
|
||||
from bertopic.representation._zeroshot import ZeroShotClassification
|
||||
except ModuleNotFoundError:
|
||||
msg = "`pip install bertopic` without `--no-deps` \n\n"
|
||||
ZeroShotClassification = NotInstalled("ZeroShotClassification", "transformers", custom_msg=msg)
|
||||
|
||||
# OpenAI Generator
|
||||
try:
|
||||
from bertopic.representation._openai import OpenAI
|
||||
except ModuleNotFoundError:
|
||||
msg = "`pip install openai` \n\n"
|
||||
OpenAI = NotInstalled("OpenAI", "openai", custom_msg=msg)
|
||||
|
||||
# LiteLLM Generator
|
||||
try:
|
||||
from bertopic.representation._litellm import LiteLLM
|
||||
except ModuleNotFoundError:
|
||||
msg = "`pip install litellm` \n\n"
|
||||
LiteLLM = NotInstalled("LiteLLM", "litellm", custom_msg=msg)
|
||||
|
||||
# LangChain Generator
|
||||
try:
|
||||
from bertopic.representation._langchain import LangChain
|
||||
except ModuleNotFoundError:
|
||||
msg = "`pip install langchain` \n\n"
|
||||
LangChain = NotInstalled("langchain", "langchain", custom_msg=msg)
|
||||
|
||||
# POS using Spacy
|
||||
try:
|
||||
from bertopic.representation._pos import PartOfSpeech
|
||||
except ModuleNotFoundError:
|
||||
PartOfSpeech = NotInstalled("Part of Speech with Spacy", "spacy")
|
||||
|
||||
# Multimodal
|
||||
try:
|
||||
from bertopic.representation._visual import VisualRepresentation
|
||||
except ModuleNotFoundError:
|
||||
VisualRepresentation = NotInstalled("a visual representation model", "vision")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BaseRepresentation",
|
||||
"TextGeneration",
|
||||
"ZeroShotClassification",
|
||||
"KeyBERTInspired",
|
||||
"PartOfSpeech",
|
||||
"MaximalMarginalRelevance",
|
||||
"Cohere",
|
||||
"OpenAI",
|
||||
"LangChain",
|
||||
"LiteLLM",
|
||||
"LlamaCPP",
|
||||
"VisualRepresentation",
|
||||
]
|
||||
@@ -0,0 +1,40 @@
|
||||
import pandas as pd
|
||||
from scipy.sparse import csr_matrix
|
||||
from sklearn.base import BaseEstimator
|
||||
from typing import Mapping, List, Tuple
|
||||
|
||||
|
||||
class BaseRepresentation(BaseEstimator):
|
||||
"""The base representation model for fine-tuning topic representations."""
|
||||
|
||||
def extract_topics(
|
||||
self,
|
||||
topic_model,
|
||||
documents: pd.DataFrame,
|
||||
c_tf_idf: csr_matrix,
|
||||
topics: Mapping[str, List[Tuple[str, float]]],
|
||||
) -> Mapping[str, List[Tuple[str, float]]]:
|
||||
"""Extract topics.
|
||||
|
||||
Each representation model that inherits this class will have
|
||||
its arguments (topic_model, documents, c_tf_idf, topics)
|
||||
automatically passed. Therefore, the representation model
|
||||
will only have access to the information about topics related
|
||||
to those arguments.
|
||||
|
||||
Arguments:
|
||||
topic_model: The BERTopic model that is fitted until topic
|
||||
representations are calculated.
|
||||
documents: A dataframe with columns "Document" and "Topic"
|
||||
that contains all documents with each corresponding
|
||||
topic.
|
||||
c_tf_idf: A c-TF-IDF representation that is typically
|
||||
identical to `topic_model.c_tf_idf_` except for
|
||||
dynamic, class-based, and hierarchical topic modeling
|
||||
where it is calculated on a subset of the documents.
|
||||
topics: A dictionary with topic (key) and tuple of word and
|
||||
weight (value) as calculated by c-TF-IDF. This is the
|
||||
default topics that are returned if no representation
|
||||
model is used.
|
||||
"""
|
||||
return topic_model.topic_representations_
|
||||
@@ -0,0 +1,209 @@
|
||||
import time
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from scipy.sparse import csr_matrix
|
||||
from typing import Mapping, List, Tuple, Union, Callable
|
||||
from bertopic.representation._base import BaseRepresentation
|
||||
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters
|
||||
|
||||
|
||||
DEFAULT_PROMPT = """
|
||||
This is a list of texts where each collection of texts describe a topic. After each collection of texts, the name of the topic they represent is mentioned as a short-highly-descriptive title
|
||||
---
|
||||
Topic:
|
||||
Sample texts from this topic:
|
||||
- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
|
||||
- Meat, but especially beef, is the word food in terms of emissions.
|
||||
- Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one.
|
||||
|
||||
Keywords: meat beef eat eating emissions steak food health processed chicken
|
||||
Topic name: Environmental impacts of eating meat
|
||||
---
|
||||
Topic:
|
||||
Sample texts from this topic:
|
||||
- I have ordered the product weeks ago but it still has not arrived!
|
||||
- The website mentions that it only takes a couple of days to deliver but I still have not received mine.
|
||||
- I got a message stating that I received the monitor but that is not true!
|
||||
- It took a month longer to deliver than was advised...
|
||||
|
||||
Keywords: deliver weeks product shipping long delivery received arrived arrive week
|
||||
Topic name: Shipping and delivery issues
|
||||
---
|
||||
Topic:
|
||||
Sample texts from this topic:
|
||||
[DOCUMENTS]
|
||||
Keywords: [KEYWORDS]
|
||||
Topic name:"""
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = "You are an assistant that extracts high-level topics from texts."
|
||||
|
||||
|
||||
class Cohere(BaseRepresentation):
|
||||
"""Use the Cohere API to generate topic labels based on their
|
||||
generative model.
|
||||
|
||||
Find more about their models here:
|
||||
https://docs.cohere.ai/docs
|
||||
|
||||
Arguments:
|
||||
client: A `cohere.Client`
|
||||
model: Model to use within Cohere, defaults to `"xlarge"`.
|
||||
prompt: The prompt to be used in the model. If no prompt is given,
|
||||
`self.default_prompt_` is used instead.
|
||||
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
|
||||
to decide where the keywords and documents need to be
|
||||
inserted.
|
||||
system_prompt: The system prompt to be used in the model. If no system prompt is given,
|
||||
`self.default_system_prompt_` is used instead.
|
||||
delay_in_seconds: The delay in seconds between consecutive prompts
|
||||
in order to prevent RateLimitErrors.
|
||||
nr_docs: The number of documents to pass to OpenAI if a prompt
|
||||
with the `["DOCUMENTS"]` tag is used.
|
||||
diversity: The diversity of documents to pass to OpenAI.
|
||||
Accepts values between 0 and 1. A higher
|
||||
values results in passing more diverse documents
|
||||
whereas lower values passes more similar documents.
|
||||
doc_length: The maximum length of each document. If a document is longer,
|
||||
it will be truncated. If None, the entire document is passed.
|
||||
tokenizer: The tokenizer used to calculate to split the document into segments
|
||||
used to count the length of a document.
|
||||
* If tokenizer is 'char', then the document is split up
|
||||
into characters which are counted to adhere to `doc_length`
|
||||
* If tokenizer is 'whitespace', the document is split up
|
||||
into words separated by whitespaces. These words are counted
|
||||
and truncated depending on `doc_length`
|
||||
* If tokenizer is 'vectorizer', then the internal CountVectorizer
|
||||
is used to tokenize the document. These tokens are counted
|
||||
and truncated depending on `doc_length`
|
||||
* If tokenizer is a callable, then that callable is used to tokenize
|
||||
the document. These tokens are counted and truncated depending
|
||||
on `doc_length`
|
||||
|
||||
Usage:
|
||||
|
||||
To use this, you will need to install cohere first:
|
||||
|
||||
`pip install cohere`
|
||||
|
||||
Then, get yourself an API key and use Cohere's API as follows:
|
||||
|
||||
```python
|
||||
import cohere
|
||||
from bertopic.representation import Cohere
|
||||
from bertopic import BERTopic
|
||||
|
||||
# Create your representation model
|
||||
co = cohere.Client(my_api_key)
|
||||
representation_model = Cohere(co)
|
||||
|
||||
# Use the representation model in BERTopic on top of the default pipeline
|
||||
topic_model = BERTopic(representation_model=representation_model)
|
||||
```
|
||||
|
||||
You can also use a custom prompt:
|
||||
|
||||
```python
|
||||
prompt = "I have the following documents: [DOCUMENTS]. What topic do they contain?"
|
||||
representation_model = Cohere(co, prompt=prompt)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client,
|
||||
model: str = "command-r",
|
||||
prompt: str = None,
|
||||
system_prompt: str = None,
|
||||
delay_in_seconds: float = None,
|
||||
nr_docs: int = 4,
|
||||
diversity: float = None,
|
||||
doc_length: int = None,
|
||||
tokenizer: Union[str, Callable] = None,
|
||||
):
|
||||
self.client = client
|
||||
self.model = model
|
||||
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
|
||||
self.system_prompt = system_prompt if system_prompt is not None else DEFAULT_SYSTEM_PROMPT
|
||||
self.default_prompt_ = DEFAULT_PROMPT
|
||||
self.default_system_prompt_ = DEFAULT_SYSTEM_PROMPT
|
||||
self.delay_in_seconds = delay_in_seconds
|
||||
self.nr_docs = nr_docs
|
||||
self.diversity = diversity
|
||||
self.doc_length = doc_length
|
||||
self.tokenizer = tokenizer
|
||||
validate_truncate_document_parameters(self.tokenizer, self.doc_length)
|
||||
|
||||
self.prompts_ = []
|
||||
|
||||
def extract_topics(
|
||||
self,
|
||||
topic_model,
|
||||
documents: pd.DataFrame,
|
||||
c_tf_idf: csr_matrix,
|
||||
topics: Mapping[str, List[Tuple[str, float]]],
|
||||
) -> Mapping[str, List[Tuple[str, float]]]:
|
||||
"""Extract topics.
|
||||
|
||||
Arguments:
|
||||
topic_model: Not used
|
||||
documents: Not used
|
||||
c_tf_idf: Not used
|
||||
topics: The candidate topics as calculated with c-TF-IDF
|
||||
|
||||
Returns:
|
||||
updated_topics: Updated topic representations
|
||||
"""
|
||||
# Extract the top 4 representative documents per topic
|
||||
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
|
||||
c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity
|
||||
)
|
||||
|
||||
# Generate using Cohere's Language Model
|
||||
updated_topics = {}
|
||||
for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
|
||||
truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
|
||||
prompt = self._create_prompt(truncated_docs, topic, topics)
|
||||
self.prompts_.append(prompt)
|
||||
|
||||
# Delay
|
||||
if self.delay_in_seconds:
|
||||
time.sleep(self.delay_in_seconds)
|
||||
|
||||
request = self.client.chat(
|
||||
model=self.model,
|
||||
preamble=self.system_prompt,
|
||||
message=prompt,
|
||||
max_tokens=50,
|
||||
stop_sequences=["\n"],
|
||||
)
|
||||
label = request.text.strip()
|
||||
updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)]
|
||||
|
||||
return updated_topics
|
||||
|
||||
def _create_prompt(self, docs, topic, topics):
|
||||
keywords = list(zip(*topics[topic]))[0]
|
||||
|
||||
# Use the Default Chat Prompt
|
||||
if self.prompt == DEFAULT_PROMPT:
|
||||
prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
|
||||
prompt = self._replace_documents(prompt, docs)
|
||||
|
||||
# Use a custom prompt that leverages keywords, documents or both using
|
||||
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
|
||||
else:
|
||||
prompt = self.prompt
|
||||
if "[KEYWORDS]" in prompt:
|
||||
prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords))
|
||||
if "[DOCUMENTS]" in prompt:
|
||||
prompt = self._replace_documents(prompt, docs)
|
||||
|
||||
return prompt
|
||||
|
||||
@staticmethod
|
||||
def _replace_documents(prompt, docs):
|
||||
to_replace = ""
|
||||
for doc in docs:
|
||||
to_replace += f"- {doc}\n"
|
||||
prompt = prompt.replace("[DOCUMENTS]", to_replace)
|
||||
return prompt
|
||||
@@ -0,0 +1,222 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from packaging import version
|
||||
from scipy.sparse import csr_matrix
|
||||
from typing import Mapping, List, Tuple, Union
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from bertopic.representation._base import BaseRepresentation
|
||||
from sklearn import __version__ as sklearn_version
|
||||
|
||||
|
||||
class KeyBERTInspired(BaseRepresentation):
|
||||
def __init__(
|
||||
self,
|
||||
top_n_words: int = 10,
|
||||
nr_repr_docs: int = 5,
|
||||
nr_samples: int = 500,
|
||||
nr_candidate_words: int = 100,
|
||||
random_state: int = 42,
|
||||
):
|
||||
"""Use a KeyBERT-like model to fine-tune the topic representations.
|
||||
|
||||
The algorithm follows KeyBERT but does some optimization in
|
||||
order to speed up inference.
|
||||
|
||||
The steps are as follows. First, we extract the top n representative
|
||||
documents per topic. To extract the representative documents, we
|
||||
randomly sample a number of candidate documents per cluster
|
||||
which is controlled by the `nr_samples` parameter. Then,
|
||||
the top n representative documents are extracted by calculating
|
||||
the c-TF-IDF representation for the candidate documents and finding,
|
||||
through cosine similarity, which are closest to the topic c-TF-IDF representation.
|
||||
Next, the top n words per topic are extracted based on their
|
||||
c-TF-IDF representation, which is controlled by the `nr_repr_docs`
|
||||
parameter.
|
||||
|
||||
Then, we extract the embeddings for words and representative documents
|
||||
and create topic embeddings by averaging the representative documents.
|
||||
Finally, the most similar words to each topic are extracted by
|
||||
calculating the cosine similarity between word and topic embeddings.
|
||||
|
||||
Arguments:
|
||||
top_n_words: The top n words to extract per topic.
|
||||
nr_repr_docs: The number of representative documents to extract per cluster.
|
||||
nr_samples: The number of candidate documents to extract per cluster.
|
||||
nr_candidate_words: The number of candidate words per cluster.
|
||||
random_state: The random state for randomly sampling candidate documents.
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
from bertopic.representation import KeyBERTInspired
|
||||
from bertopic import BERTopic
|
||||
|
||||
# Create your representation model
|
||||
representation_model = KeyBERTInspired()
|
||||
|
||||
# Use the representation model in BERTopic on top of the default pipeline
|
||||
topic_model = BERTopic(representation_model=representation_model)
|
||||
```
|
||||
"""
|
||||
self.top_n_words = top_n_words
|
||||
self.nr_repr_docs = nr_repr_docs
|
||||
self.nr_samples = nr_samples
|
||||
self.nr_candidate_words = nr_candidate_words
|
||||
self.random_state = random_state
|
||||
|
||||
def extract_topics(
|
||||
self,
|
||||
topic_model,
|
||||
documents: pd.DataFrame,
|
||||
c_tf_idf: csr_matrix,
|
||||
topics: Mapping[str, List[Tuple[str, float]]],
|
||||
embeddings: np.ndarray = None,
|
||||
) -> Mapping[str, List[Tuple[str, float]]]:
|
||||
"""Extract topics.
|
||||
|
||||
Arguments:
|
||||
topic_model: A BERTopic model
|
||||
documents: All input documents
|
||||
c_tf_idf: The topic c-TF-IDF representation
|
||||
topics: The candidate topics as calculated with c-TF-IDF
|
||||
embeddings: Pre-trained document embeddings. These can be used
|
||||
instead of an embedding model
|
||||
|
||||
Returns:
|
||||
updated_topics: Updated topic representations
|
||||
"""
|
||||
# We extract the top n representative documents per class
|
||||
_, representative_docs, repr_doc_indices, _ = topic_model._extract_representative_docs(
|
||||
c_tf_idf, documents, topics, self.nr_samples, self.nr_repr_docs
|
||||
)
|
||||
|
||||
# If document embeddings are precomputed, extract the embeddings of the representative documents based on repr_doc_indices
|
||||
repr_embeddings = None
|
||||
if embeddings is not None:
|
||||
repr_embeddings = [embeddings[index] for index in np.concatenate(repr_doc_indices)]
|
||||
|
||||
# We extract the top n words per class
|
||||
topics = self._extract_candidate_words(topic_model, c_tf_idf, topics)
|
||||
|
||||
# We calculate the similarity between word and document embeddings and create
|
||||
# topic embeddings from the representative document embeddings
|
||||
sim_matrix, words = self._extract_embeddings(
|
||||
topic_model, topics, representative_docs, repr_doc_indices, repr_embeddings
|
||||
)
|
||||
# Find the best matching words based on the similarity matrix for each topic
|
||||
updated_topics = self._extract_top_words(words, topics, sim_matrix)
|
||||
|
||||
return updated_topics
|
||||
|
||||
def _extract_candidate_words(
|
||||
self,
|
||||
topic_model,
|
||||
c_tf_idf: csr_matrix,
|
||||
topics: Mapping[str, List[Tuple[str, float]]],
|
||||
) -> Mapping[str, List[Tuple[str, float]]]:
|
||||
"""For each topic, extract candidate words based on the c-TF-IDF
|
||||
representation.
|
||||
|
||||
Arguments:
|
||||
topic_model: A BERTopic model
|
||||
c_tf_idf: The topic c-TF-IDF representation
|
||||
topics: The top words per topic
|
||||
|
||||
Returns:
|
||||
topics: The `self.top_n_words` per topic
|
||||
"""
|
||||
labels = [int(label) for label in sorted(list(topics.keys()))]
|
||||
|
||||
# Scikit-Learn Deprecation: get_feature_names is deprecated in 1.0
|
||||
# and will be removed in 1.2. Please use get_feature_names_out instead.
|
||||
if version.parse(sklearn_version) >= version.parse("1.0.0"):
|
||||
words = topic_model.vectorizer_model.get_feature_names_out()
|
||||
else:
|
||||
words = topic_model.vectorizer_model.get_feature_names()
|
||||
|
||||
indices = topic_model._top_n_idx_sparse(c_tf_idf, self.nr_candidate_words)
|
||||
scores = topic_model._top_n_values_sparse(c_tf_idf, indices)
|
||||
sorted_indices = np.argsort(scores, 1)
|
||||
indices = np.take_along_axis(indices, sorted_indices, axis=1)
|
||||
scores = np.take_along_axis(scores, sorted_indices, axis=1)
|
||||
|
||||
# Get top 30 words per topic based on c-TF-IDF score
|
||||
topics = {
|
||||
label: [
|
||||
(words[word_index], score) if word_index is not None and score > 0 else ("", 0.00001)
|
||||
for word_index, score in zip(indices[index][::-1], scores[index][::-1])
|
||||
]
|
||||
for index, label in enumerate(labels)
|
||||
}
|
||||
topics = {label: list(zip(*values[: self.nr_candidate_words]))[0] for label, values in topics.items()}
|
||||
|
||||
return topics
|
||||
|
||||
def _extract_embeddings(
|
||||
self,
|
||||
topic_model,
|
||||
topics: Mapping[str, List[Tuple[str, float]]],
|
||||
representative_docs: List[str],
|
||||
repr_doc_indices: List[List[int]],
|
||||
repr_embeddings: np.ndarray = None,
|
||||
) -> Union[np.ndarray, List[str]]:
|
||||
"""Extract the representative document embeddings and create topic embeddings.
|
||||
Then extract word embeddings and calculate the cosine similarity between topic
|
||||
embeddings and the word embeddings. Topic embeddings are the average of
|
||||
representative document embeddings.
|
||||
|
||||
Arguments:
|
||||
topic_model: A BERTopic model
|
||||
topics: The top words per topic
|
||||
representative_docs: A flat list of representative documents
|
||||
repr_doc_indices: The indices of representative documents
|
||||
that belong to each topic
|
||||
repr_embeddings: Embeddings of respective representative_docs
|
||||
|
||||
Returns:
|
||||
sim: The similarity matrix between word and topic embeddings
|
||||
vocab: The complete vocabulary of input documents
|
||||
"""
|
||||
# Calculate representative document embeddings if there are no precomputed embeddings.
|
||||
if repr_embeddings is None:
|
||||
repr_embeddings = topic_model._extract_embeddings(representative_docs, method="document", verbose=False)
|
||||
|
||||
topic_embeddings = [np.mean(repr_embeddings[i[0] : i[-1] + 1], axis=0) for i in repr_doc_indices]
|
||||
|
||||
# Calculate word embeddings and extract best matching with updated topic_embeddings
|
||||
vocab = list(set([word for words in topics.values() for word in words]))
|
||||
word_embeddings = topic_model._extract_embeddings(vocab, method="document", verbose=False)
|
||||
sim = cosine_similarity(topic_embeddings, word_embeddings)
|
||||
|
||||
return sim, vocab
|
||||
|
||||
def _extract_top_words(
|
||||
self,
|
||||
vocab: List[str],
|
||||
topics: Mapping[str, List[Tuple[str, float]]],
|
||||
sim: np.ndarray,
|
||||
) -> Mapping[str, List[Tuple[str, float]]]:
|
||||
"""Extract the top n words per topic based on the
|
||||
similarity matrix between topics and words.
|
||||
|
||||
Arguments:
|
||||
vocab: The complete vocabulary of input documents
|
||||
labels: All topic labels
|
||||
topics: The top words per topic
|
||||
sim: The similarity matrix between word and topic embeddings
|
||||
|
||||
Returns:
|
||||
updated_topics: The updated topic representations
|
||||
"""
|
||||
labels = [int(label) for label in sorted(list(topics.keys()))]
|
||||
updated_topics = {}
|
||||
for i, topic in enumerate(labels):
|
||||
indices = [vocab.index(word) for word in topics[topic]]
|
||||
values = sim[:, indices][i]
|
||||
word_indices = [indices[index] for index in np.argsort(values)[-self.top_n_words :]]
|
||||
updated_topics[topic] = [
|
||||
(vocab[index], val) for val, index in zip(np.sort(values)[-self.top_n_words :], word_indices)
|
||||
][::-1]
|
||||
|
||||
return updated_topics
|
||||
@@ -0,0 +1,213 @@
|
||||
import pandas as pd
|
||||
from langchain.docstore.document import Document
|
||||
from scipy.sparse import csr_matrix
|
||||
from typing import Callable, Mapping, List, Tuple, Union
|
||||
|
||||
from bertopic.representation._base import BaseRepresentation
|
||||
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters
|
||||
|
||||
DEFAULT_PROMPT = "What are these documents about? Please give a single label."
|
||||
|
||||
|
||||
class LangChain(BaseRepresentation):
|
||||
"""Using chains in langchain to generate topic labels.
|
||||
|
||||
The classic example uses `langchain.chains.question_answering.load_qa_chain`.
|
||||
This returns a chain that takes a list of documents and a question as input.
|
||||
|
||||
You can also use Runnables such as those composed using the LangChain Expression Language.
|
||||
|
||||
Arguments:
|
||||
chain: The langchain chain or Runnable with a `batch` method.
|
||||
Input keys must be `input_documents` and `question`.
|
||||
Output key must be `output_text`.
|
||||
prompt: The prompt to be used in the model. If no prompt is given,
|
||||
`self.default_prompt_` is used instead.
|
||||
NOTE: Use `"[KEYWORDS]"` in the prompt
|
||||
to decide where the keywords need to be
|
||||
inserted. Keywords won't be included unless
|
||||
indicated. Unlike other representation models,
|
||||
Langchain does not use the `"[DOCUMENTS]"` tag
|
||||
to insert documents into the prompt. The load_qa_chain function
|
||||
formats the representative documents within the prompt.
|
||||
nr_docs: The number of documents to pass to LangChain
|
||||
diversity: The diversity of documents to pass to LangChain.
|
||||
Accepts values between 0 and 1. A higher
|
||||
values results in passing more diverse documents
|
||||
whereas lower values passes more similar documents.
|
||||
doc_length: The maximum length of each document. If a document is longer,
|
||||
it will be truncated. If None, the entire document is passed.
|
||||
tokenizer: The tokenizer used to calculate to split the document into segments
|
||||
used to count the length of a document.
|
||||
* If tokenizer is 'char', then the document is split up
|
||||
into characters which are counted to adhere to `doc_length`
|
||||
* If tokenizer is 'whitespace', the document is split up
|
||||
into words separated by whitespaces. These words are counted
|
||||
and truncated depending on `doc_length`
|
||||
* If tokenizer is 'vectorizer', then the internal CountVectorizer
|
||||
is used to tokenize the document. These tokens are counted
|
||||
and truncated depending on `doc_length`. They are decoded with
|
||||
whitespaces.
|
||||
* If tokenizer is a callable, then that callable is used to tokenize
|
||||
the document. These tokens are counted and truncated depending
|
||||
on `doc_length`
|
||||
chain_config: The configuration for the langchain chain. Can be used to set options
|
||||
like max_concurrency to avoid rate limiting errors.
|
||||
Usage:
|
||||
|
||||
To use this, you will need to install the langchain package first.
|
||||
Additionally, you will need an underlying LLM to support langchain,
|
||||
like openai:
|
||||
|
||||
`pip install langchain`
|
||||
`pip install openai`
|
||||
|
||||
Then, you can create your chain as follows:
|
||||
|
||||
```python
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.llms import OpenAI
|
||||
chain = load_qa_chain(OpenAI(temperature=0, openai_api_key=my_openai_api_key), chain_type="stuff")
|
||||
```
|
||||
|
||||
Finally, you can pass the chain to BERTopic as follows:
|
||||
|
||||
```python
|
||||
from bertopic.representation import LangChain
|
||||
|
||||
# Create your representation model
|
||||
representation_model = LangChain(chain)
|
||||
|
||||
# Use the representation model in BERTopic on top of the default pipeline
|
||||
topic_model = BERTopic(representation_model=representation_model)
|
||||
```
|
||||
|
||||
You can also use a custom prompt:
|
||||
|
||||
```python
|
||||
prompt = "What are these documents about? Please give a single label."
|
||||
representation_model = LangChain(chain, prompt=prompt)
|
||||
```
|
||||
|
||||
You can also use a Runnable instead of a chain.
|
||||
The example below uses the LangChain Expression Language:
|
||||
|
||||
```python
|
||||
from bertopic.representation import LangChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.runnable import RunnablePassthrough
|
||||
from langchain_experimental.data_anonymizer.presidio import PresidioReversibleAnonymizer
|
||||
|
||||
prompt = ...
|
||||
llm = ...
|
||||
|
||||
# We will construct a special privacy-preserving chain using Microsoft Presidio
|
||||
|
||||
pii_handler = PresidioReversibleAnonymizer(analyzed_fields=["PERSON"])
|
||||
|
||||
chain = (
|
||||
{
|
||||
"input_documents": (
|
||||
lambda inp: [
|
||||
Document(
|
||||
page_content=pii_handler.anonymize(
|
||||
d.page_content,
|
||||
language="en",
|
||||
),
|
||||
)
|
||||
for d in inp["input_documents"]
|
||||
]
|
||||
),
|
||||
"question": RunnablePassthrough(),
|
||||
}
|
||||
| load_qa_chain(representation_llm, chain_type="stuff")
|
||||
| (lambda output: {"output_text": pii_handler.deanonymize(output["output_text"])})
|
||||
)
|
||||
|
||||
representation_model = LangChain(chain, prompt=representation_prompt)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chain,
|
||||
prompt: str = None,
|
||||
nr_docs: int = 4,
|
||||
diversity: float = None,
|
||||
doc_length: int = None,
|
||||
tokenizer: Union[str, Callable] = None,
|
||||
chain_config=None,
|
||||
):
|
||||
self.chain = chain
|
||||
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
|
||||
self.default_prompt_ = DEFAULT_PROMPT
|
||||
self.chain_config = chain_config
|
||||
self.nr_docs = nr_docs
|
||||
self.diversity = diversity
|
||||
self.doc_length = doc_length
|
||||
self.tokenizer = tokenizer
|
||||
validate_truncate_document_parameters(self.tokenizer, self.doc_length)
|
||||
|
||||
def extract_topics(
|
||||
self,
|
||||
topic_model,
|
||||
documents: pd.DataFrame,
|
||||
c_tf_idf: csr_matrix,
|
||||
topics: Mapping[str, List[Tuple[str, float]]],
|
||||
) -> Mapping[str, List[Tuple[str, int]]]:
|
||||
"""Extract topics.
|
||||
|
||||
Arguments:
|
||||
topic_model: A BERTopic model
|
||||
documents: All input documents
|
||||
c_tf_idf: The topic c-TF-IDF representation
|
||||
topics: The candidate topics as calculated with c-TF-IDF
|
||||
|
||||
Returns:
|
||||
updated_topics: Updated topic representations
|
||||
"""
|
||||
# Extract the top 4 representative documents per topic
|
||||
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
|
||||
c_tf_idf=c_tf_idf,
|
||||
documents=documents,
|
||||
topics=topics,
|
||||
nr_samples=500,
|
||||
nr_repr_docs=self.nr_docs,
|
||||
diversity=self.diversity,
|
||||
)
|
||||
|
||||
# Generate label using langchain's batch functionality
|
||||
chain_docs: List[List[Document]] = [
|
||||
[
|
||||
Document(page_content=truncate_document(topic_model, self.doc_length, self.tokenizer, doc))
|
||||
for doc in docs
|
||||
]
|
||||
for docs in repr_docs_mappings.values()
|
||||
]
|
||||
|
||||
# `self.chain` must take `input_documents` and `question` as input keys
|
||||
# Use a custom prompt that leverages keywords, using the tag: [KEYWORDS]
|
||||
if "[KEYWORDS]" in self.prompt:
|
||||
prompts = []
|
||||
for topic in topics:
|
||||
keywords = list(zip(*topics[topic]))[0]
|
||||
prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
|
||||
prompts.append(prompt)
|
||||
|
||||
inputs = [{"input_documents": docs, "question": prompt} for docs, prompt in zip(chain_docs, prompts)]
|
||||
|
||||
else:
|
||||
inputs = [{"input_documents": docs, "question": self.prompt} for docs in chain_docs]
|
||||
|
||||
# `self.chain` must return a dict with an `output_text` key
|
||||
# same output key as the `StuffDocumentsChain` returned by `load_qa_chain`
|
||||
outputs = self.chain.batch(inputs=inputs, config=self.chain_config)
|
||||
labels = [output["output_text"].strip() for output in outputs]
|
||||
|
||||
updated_topics = {
|
||||
topic: [(label, 1)] + [("", 0) for _ in range(9)] for topic, label in zip(repr_docs_mappings.keys(), labels)
|
||||
}
|
||||
|
||||
return updated_topics
|
||||
@@ -0,0 +1,176 @@
|
||||
import time
|
||||
from litellm import completion
|
||||
import pandas as pd
|
||||
from scipy.sparse import csr_matrix
|
||||
from typing import Mapping, List, Tuple, Any
|
||||
from bertopic.representation._base import BaseRepresentation
|
||||
from bertopic.representation._utils import retry_with_exponential_backoff
|
||||
|
||||
|
||||
DEFAULT_PROMPT = """
|
||||
I have a topic that contains the following documents:
|
||||
[DOCUMENTS]
|
||||
The topic is described by the following keywords: [KEYWORDS]
|
||||
Based on the information above, extract a short topic label in the following format:
|
||||
topic: <topic label>
|
||||
"""
|
||||
|
||||
|
||||
class LiteLLM(BaseRepresentation):
|
||||
"""Using the LiteLLM API to generate topic labels.
|
||||
|
||||
For an overview of models see:
|
||||
https://docs.litellm.ai/docs/providers
|
||||
|
||||
Arguments:
|
||||
model: Model to use. Defaults to OpenAI's "gpt-3.5-turbo".
|
||||
generator_kwargs: Kwargs passed to `litellm.completion`.
|
||||
prompt: The prompt to be used in the model. If no prompt is given,
|
||||
`self.default_prompt_` is used instead.
|
||||
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
|
||||
to decide where the keywords and documents need to be
|
||||
inserted.
|
||||
delay_in_seconds: The delay in seconds between consecutive prompts
|
||||
in order to prevent RateLimitErrors.
|
||||
exponential_backoff: Retry requests with a random exponential backoff.
|
||||
A short sleep is used when a rate limit error is hit,
|
||||
then the requests is retried. Increase the sleep length
|
||||
if errors are hit until 10 unsuccesfull requests.
|
||||
If True, overrides `delay_in_seconds`.
|
||||
nr_docs: The number of documents to pass to LiteLLM if a prompt
|
||||
with the `["DOCUMENTS"]` tag is used.
|
||||
diversity: The diversity of documents to pass to LiteLLM.
|
||||
Accepts values between 0 and 1. A higher
|
||||
values results in passing more diverse documents
|
||||
whereas lower values passes more similar documents.
|
||||
|
||||
Usage:
|
||||
|
||||
To use this, you will need to install the litellm package first:
|
||||
|
||||
`pip install litellm`
|
||||
|
||||
Then, get yourself an API key of any provider (for instance OpenAI) and use it as follows:
|
||||
|
||||
```python
|
||||
import os
|
||||
from bertopic.representation import LiteLLM
|
||||
from bertopic import BERTopic
|
||||
|
||||
# set ENV variables
|
||||
os.environ["OPENAI_API_KEY"] = "your-openai-key"
|
||||
|
||||
# Create your representation model
|
||||
representation_model = LiteLLM(model="gpt-3.5-turbo")
|
||||
|
||||
# Use the representation model in BERTopic on top of the default pipeline
|
||||
topic_model = BERTopic(representation_model=representation_model)
|
||||
```
|
||||
|
||||
You can also use a custom prompt:
|
||||
|
||||
```python
|
||||
prompt = "I have the following documents: [DOCUMENTS] \nThese documents are about the following topic: '"
|
||||
representation_model = LiteLLM(model="gpt", prompt=prompt)
|
||||
```
|
||||
""" # noqa: D301
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-3.5-turbo",
|
||||
prompt: str = None,
|
||||
generator_kwargs: Mapping[str, Any] = {},
|
||||
delay_in_seconds: float = None,
|
||||
exponential_backoff: bool = False,
|
||||
nr_docs: int = 4,
|
||||
diversity: float = None,
|
||||
):
|
||||
self.model = model
|
||||
self.prompt = prompt if prompt else DEFAULT_PROMPT
|
||||
self.default_prompt_ = DEFAULT_PROMPT
|
||||
self.delay_in_seconds = delay_in_seconds
|
||||
self.exponential_backoff = exponential_backoff
|
||||
self.nr_docs = nr_docs
|
||||
self.diversity = diversity
|
||||
|
||||
self.generator_kwargs = generator_kwargs
|
||||
if self.generator_kwargs.get("model"):
|
||||
self.model = generator_kwargs.get("model")
|
||||
if self.generator_kwargs.get("prompt"):
|
||||
del self.generator_kwargs["prompt"]
|
||||
|
||||
def extract_topics(
|
||||
self, topic_model, documents: pd.DataFrame, c_tf_idf: csr_matrix, topics: Mapping[str, List[Tuple[str, float]]]
|
||||
) -> Mapping[str, List[Tuple[str, float]]]:
|
||||
"""Extract topics.
|
||||
|
||||
Arguments:
|
||||
topic_model: A BERTopic model
|
||||
documents: All input documents
|
||||
c_tf_idf: The topic c-TF-IDF representation
|
||||
topics: The candidate topics as calculated with c-TF-IDF
|
||||
|
||||
Returns:
|
||||
updated_topics: Updated topic representations
|
||||
"""
|
||||
# Extract the top n representative documents per topic
|
||||
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
|
||||
c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity
|
||||
)
|
||||
|
||||
# Generate using a (Large) Language Model
|
||||
updated_topics = {}
|
||||
for topic, docs in repr_docs_mappings.items():
|
||||
prompt = self._create_prompt(docs, topic, topics)
|
||||
|
||||
# Delay
|
||||
if self.delay_in_seconds:
|
||||
time.sleep(self.delay_in_seconds)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs}
|
||||
if self.exponential_backoff:
|
||||
response = chat_completions_with_backoff(**kwargs)
|
||||
else:
|
||||
response = completion(**kwargs)
|
||||
label = response["choices"][0]["message"]["content"].strip().replace("topic: ", "")
|
||||
|
||||
updated_topics[topic] = [(label, 1)]
|
||||
|
||||
return updated_topics
|
||||
|
||||
def _create_prompt(self, docs, topic, topics):
|
||||
keywords = list(zip(*topics[topic]))[0]
|
||||
|
||||
# Use the Default Chat Prompt
|
||||
if self.prompt == DEFAULT_PROMPT:
|
||||
prompt = self.prompt.replace("[KEYWORDS]", " ".join(keywords))
|
||||
prompt = self._replace_documents(prompt, docs)
|
||||
|
||||
# Use a custom prompt that leverages keywords, documents or both using
|
||||
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
|
||||
else:
|
||||
prompt = self.prompt
|
||||
if "[KEYWORDS]" in prompt:
|
||||
prompt = prompt.replace("[KEYWORDS]", " ".join(keywords))
|
||||
if "[DOCUMENTS]" in prompt:
|
||||
prompt = self._replace_documents(prompt, docs)
|
||||
|
||||
return prompt
|
||||
|
||||
@staticmethod
|
||||
def _replace_documents(prompt, docs):
|
||||
to_replace = ""
|
||||
for doc in docs:
|
||||
to_replace += f"- {doc[:255]}\n"
|
||||
prompt = prompt.replace("[DOCUMENTS]", to_replace)
|
||||
return prompt
|
||||
|
||||
|
||||
def chat_completions_with_backoff(**kwargs):
|
||||
return retry_with_exponential_backoff(
|
||||
completion,
|
||||
)(**kwargs)
|
||||
@@ -0,0 +1,215 @@
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from scipy.sparse import csr_matrix
|
||||
from llama_cpp import Llama
|
||||
from typing import Mapping, List, Tuple, Any, Union, Callable
|
||||
from bertopic.representation._base import BaseRepresentation
|
||||
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters
|
||||
|
||||
|
||||
DEFAULT_PROMPT = """
|
||||
This is a list of texts where each collection of texts describe a topic. After each collection of texts, the name of the topic they represent is mentioned as a short-highly-descriptive title
|
||||
---
|
||||
Topic:
|
||||
Sample texts from this topic:
|
||||
- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
|
||||
- Meat, but especially beef, is the word food in terms of emissions.
|
||||
- Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one.
|
||||
|
||||
Keywords: meat beef eat eating emissions steak food health processed chicken
|
||||
Topic name: Environmental impacts of eating meat
|
||||
---
|
||||
Topic:
|
||||
Sample texts from this topic:
|
||||
- I have ordered the product weeks ago but it still has not arrived!
|
||||
- The website mentions that it only takes a couple of days to deliver but I still have not received mine.
|
||||
- I got a message stating that I received the monitor but that is not true!
|
||||
- It took a month longer to deliver than was advised...
|
||||
|
||||
Keywords: deliver weeks product shipping long delivery received arrived arrive week
|
||||
Topic name: Shipping and delivery issues
|
||||
---
|
||||
Topic:
|
||||
Sample texts from this topic:
|
||||
[DOCUMENTS]
|
||||
Keywords: [KEYWORDS]
|
||||
Topic name:"""
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = "You are an assistant that extracts high-level topics from texts."
|
||||
|
||||
|
||||
class LlamaCPP(BaseRepresentation):
|
||||
"""A llama.cpp implementation to use as a representation model.
|
||||
|
||||
Arguments:
|
||||
model: Either a string pointing towards a local LLM or a
|
||||
`llama_cpp.Llama` object.
|
||||
prompt: The prompt to be used in the model. If no prompt is given,
|
||||
`self.default_prompt_` is used instead.
|
||||
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
|
||||
to decide where the keywords and documents need to be
|
||||
inserted.
|
||||
system_prompt: The system prompt to be used in the model. If no system prompt is given,
|
||||
`self.default_system_prompt_` is used instead.
|
||||
pipeline_kwargs: Kwargs that you can pass to the `llama_cpp.Llama`
|
||||
when it is called such as `max_tokens` to be generated.
|
||||
nr_docs: The number of documents to pass to OpenAI if a prompt
|
||||
with the `["DOCUMENTS"]` tag is used.
|
||||
diversity: The diversity of documents to pass to OpenAI.
|
||||
Accepts values between 0 and 1. A higher
|
||||
values results in passing more diverse documents
|
||||
whereas lower values passes more similar documents.
|
||||
doc_length: The maximum length of each document. If a document is longer,
|
||||
it will be truncated. If None, the entire document is passed.
|
||||
tokenizer: The tokenizer used to calculate to split the document into segments
|
||||
used to count the length of a document.
|
||||
* If tokenizer is 'char', then the document is split up
|
||||
into characters which are counted to adhere to `doc_length`
|
||||
* If tokenizer is 'whitespace', the the document is split up
|
||||
into words separated by whitespaces. These words are counted
|
||||
and truncated depending on `doc_length`
|
||||
* If tokenizer is 'vectorizer', then the internal CountVectorizer
|
||||
is used to tokenize the document. These tokens are counted
|
||||
and truncated depending on `doc_length`
|
||||
* If tokenizer is a callable, then that callable is used to tokenize
|
||||
the document. These tokens are counted and truncated depending
|
||||
on `doc_length`
|
||||
|
||||
Usage:
|
||||
|
||||
To use a llama.cpp, first download the LLM:
|
||||
|
||||
```bash
|
||||
wget https://huggingface.co/TheBloke/zephyr-7B-alpha-GGUF/resolve/main/zephyr-7b-alpha.Q4_K_M.gguf
|
||||
```
|
||||
|
||||
Then, we can now use the model the model with BERTopic in just a couple of lines:
|
||||
|
||||
```python
|
||||
from bertopic import BERTopic
|
||||
from bertopic.representation import LlamaCPP
|
||||
|
||||
# Use llama.cpp to load in a 4-bit quantized version of Zephyr 7B Alpha
|
||||
representation_model = LlamaCPP("zephyr-7b-alpha.Q4_K_M.gguf")
|
||||
|
||||
# Create our BERTopic model
|
||||
topic_model = BERTopic(representation_model=representation_model, verbose=True)
|
||||
```
|
||||
|
||||
If you want to have more control over the LLMs parameters, you can run it like so:
|
||||
|
||||
```python
|
||||
from bertopic import BERTopic
|
||||
from bertopic.representation import LlamaCPP
|
||||
from llama_cpp import Llama
|
||||
|
||||
# Use llama.cpp to load in a 4-bit quantized version of Zephyr 7B Alpha
|
||||
llm = Llama(model_path="zephyr-7b-alpha.Q4_K_M.gguf", n_gpu_layers=-1, n_ctx=4096, stop="Q:")
|
||||
representation_model = LlamaCPP(llm)
|
||||
|
||||
# Create our BERTopic model
|
||||
topic_model = BERTopic(representation_model=representation_model, verbose=True)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[str, Llama],
|
||||
prompt: str = None,
|
||||
system_prompt: str = None,
|
||||
pipeline_kwargs: Mapping[str, Any] = {},
|
||||
nr_docs: int = 4,
|
||||
diversity: float = None,
|
||||
doc_length: int = None,
|
||||
tokenizer: Union[str, Callable] = None,
|
||||
):
|
||||
if isinstance(model, str):
|
||||
self.model = Llama(model_path=model, n_gpu_layers=-1, stop="\n", chat_format="ChatML")
|
||||
elif isinstance(model, Llama):
|
||||
self.model = model
|
||||
else:
|
||||
raise ValueError(
|
||||
"Make sure that the model that you"
|
||||
"pass is either a string referring to a"
|
||||
"local LLM or a ` llama_cpp.Llama` object."
|
||||
)
|
||||
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
|
||||
self.system_prompt = system_prompt if system_prompt is not None else DEFAULT_SYSTEM_PROMPT
|
||||
self.default_prompt_ = DEFAULT_PROMPT
|
||||
self.default_system_prompt_ = DEFAULT_SYSTEM_PROMPT
|
||||
self.pipeline_kwargs = pipeline_kwargs
|
||||
self.nr_docs = nr_docs
|
||||
self.diversity = diversity
|
||||
self.doc_length = doc_length
|
||||
self.tokenizer = tokenizer
|
||||
validate_truncate_document_parameters(self.tokenizer, self.doc_length)
|
||||
|
||||
self.prompts_ = []
|
||||
|
||||
def extract_topics(
|
||||
self,
|
||||
topic_model,
|
||||
documents: pd.DataFrame,
|
||||
c_tf_idf: csr_matrix,
|
||||
topics: Mapping[str, List[Tuple[str, float]]],
|
||||
) -> Mapping[str, List[Tuple[str, float]]]:
|
||||
"""Extract topic representations and return a single label.
|
||||
|
||||
Arguments:
|
||||
topic_model: A BERTopic model
|
||||
documents: Not used
|
||||
c_tf_idf: Not used
|
||||
topics: The candidate topics as calculated with c-TF-IDF
|
||||
|
||||
Returns:
|
||||
updated_topics: Updated topic representations
|
||||
"""
|
||||
# Extract the top 4 representative documents per topic
|
||||
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
|
||||
c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity
|
||||
)
|
||||
|
||||
updated_topics = {}
|
||||
for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
|
||||
# Prepare prompt
|
||||
truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
|
||||
prompt = self._create_prompt(truncated_docs, topic, topics)
|
||||
self.prompts_.append(prompt)
|
||||
|
||||
# Extract result from generator and use that as label
|
||||
# topic_description = self.model(prompt, **self.pipeline_kwargs)["choices"]
|
||||
topic_description = self.model.create_chat_completion(
|
||||
messages=[{"role": "system", "content": self.system_prompt}, {"role": "user", "content": prompt}],
|
||||
**self.pipeline_kwargs,
|
||||
)
|
||||
label = topic_description["choices"][0]["message"]["content"].strip()
|
||||
updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)]
|
||||
|
||||
return updated_topics
|
||||
|
||||
def _create_prompt(self, docs, topic, topics):
|
||||
keywords = list(zip(*topics[topic]))[0]
|
||||
|
||||
# Use the Default Chat Prompt
|
||||
if self.prompt == DEFAULT_PROMPT:
|
||||
prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
|
||||
prompt = self._replace_documents(prompt, docs)
|
||||
|
||||
# Use a custom prompt that leverages keywords, documents or both using
|
||||
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
|
||||
else:
|
||||
prompt = self.prompt
|
||||
if "[KEYWORDS]" in prompt:
|
||||
prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords))
|
||||
if "[DOCUMENTS]" in prompt:
|
||||
prompt = self._replace_documents(prompt, docs)
|
||||
|
||||
return prompt
|
||||
|
||||
@staticmethod
|
||||
def _replace_documents(prompt, docs):
|
||||
to_replace = ""
|
||||
for doc in docs:
|
||||
to_replace += f"- {doc}\n"
|
||||
prompt = prompt.replace("[DOCUMENTS]", to_replace)
|
||||
return prompt
|
||||
@@ -0,0 +1,128 @@
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import List, Mapping, Tuple
|
||||
from scipy.sparse import csr_matrix
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from bertopic.representation._base import BaseRepresentation
|
||||
|
||||
|
||||
class MaximalMarginalRelevance(BaseRepresentation):
|
||||
"""Calculate Maximal Marginal Relevance (MMR)
|
||||
between candidate keywords and the document.
|
||||
|
||||
MMR considers the similarity of keywords/keyphrases with the
|
||||
document, along with the similarity of already selected
|
||||
keywords and keyphrases. This results in a selection of keywords
|
||||
that maximize their within diversity with respect to the document.
|
||||
|
||||
Arguments:
|
||||
diversity: How diverse the select keywords/keyphrases are.
|
||||
Values range between 0 and 1 with 0 being not diverse at all
|
||||
and 1 being most diverse.
|
||||
top_n_words: The number of keywords/keyhprases to return
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
from bertopic.representation import MaximalMarginalRelevance
|
||||
from bertopic import BERTopic
|
||||
|
||||
# Create your representation model
|
||||
representation_model = MaximalMarginalRelevance(diversity=0.3)
|
||||
|
||||
# Use the representation model in BERTopic on top of the default pipeline
|
||||
topic_model = BERTopic(representation_model=representation_model)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, diversity: float = 0.1, top_n_words: int = 10):
|
||||
self.diversity = diversity
|
||||
self.top_n_words = top_n_words
|
||||
|
||||
def extract_topics(
|
||||
self,
|
||||
topic_model,
|
||||
documents: pd.DataFrame,
|
||||
c_tf_idf: csr_matrix,
|
||||
topics: Mapping[str, List[Tuple[str, float]]],
|
||||
) -> Mapping[str, List[Tuple[str, float]]]:
|
||||
"""Extract topic representations.
|
||||
|
||||
Arguments:
|
||||
topic_model: The BERTopic model
|
||||
documents: Not used
|
||||
c_tf_idf: Not used
|
||||
topics: The candidate topics as calculated with c-TF-IDF
|
||||
|
||||
Returns:
|
||||
updated_topics: Updated topic representations
|
||||
"""
|
||||
if topic_model.embedding_model is None:
|
||||
warnings.warn(
|
||||
"MaximalMarginalRelevance can only be used BERTopic was instantiated"
|
||||
"with the `embedding_model` parameter."
|
||||
)
|
||||
return topics
|
||||
|
||||
updated_topics = {}
|
||||
for topic, topic_words in topics.items():
|
||||
words = [word[0] for word in topic_words]
|
||||
word_embeddings = topic_model._extract_embeddings(words, method="word", verbose=False)
|
||||
topic_embedding = topic_model._extract_embeddings(" ".join(words), method="word", verbose=False).reshape(
|
||||
1, -1
|
||||
)
|
||||
topic_words = mmr(
|
||||
topic_embedding,
|
||||
word_embeddings,
|
||||
words,
|
||||
self.diversity,
|
||||
self.top_n_words,
|
||||
)
|
||||
updated_topics[topic] = [(word, value) for word, value in topics[topic] if word in topic_words]
|
||||
return updated_topics
|
||||
|
||||
|
||||
def mmr(
|
||||
doc_embedding: np.ndarray,
|
||||
word_embeddings: np.ndarray,
|
||||
words: List[str],
|
||||
diversity: float = 0.1,
|
||||
top_n: int = 10,
|
||||
) -> List[str]:
|
||||
"""Maximal Marginal Relevance.
|
||||
|
||||
Arguments:
|
||||
doc_embedding: The document embeddings
|
||||
word_embeddings: The embeddings of the selected candidate keywords/phrases
|
||||
words: The selected candidate keywords/keyphrases
|
||||
diversity: The diversity of the selected embeddings.
|
||||
Values between 0 and 1.
|
||||
top_n: The top n items to return
|
||||
|
||||
Returns:
|
||||
List[str]: The selected keywords/keyphrases
|
||||
"""
|
||||
# Extract similarity within words, and between words and the document
|
||||
word_doc_similarity = cosine_similarity(word_embeddings, doc_embedding)
|
||||
word_similarity = cosine_similarity(word_embeddings)
|
||||
|
||||
# Initialize candidates and already choose best keyword/keyphras
|
||||
keywords_idx = [np.argmax(word_doc_similarity)]
|
||||
candidates_idx = [i for i in range(len(words)) if i != keywords_idx[0]]
|
||||
|
||||
for _ in range(top_n - 1):
|
||||
# Extract similarities within candidates and
|
||||
# between candidates and selected keywords/phrases
|
||||
candidate_similarities = word_doc_similarity[candidates_idx, :]
|
||||
target_similarities = np.max(word_similarity[candidates_idx][:, keywords_idx], axis=1)
|
||||
|
||||
# Calculate MMR
|
||||
mmr = (1 - diversity) * candidate_similarities - diversity * target_similarities.reshape(-1, 1)
|
||||
mmr_idx = candidates_idx[np.argmax(mmr)]
|
||||
|
||||
# Update keywords & candidates
|
||||
keywords_idx.append(mmr_idx)
|
||||
candidates_idx.remove(mmr_idx)
|
||||
|
||||
return [words[idx] for idx in keywords_idx]
|
||||
@@ -0,0 +1,274 @@
|
||||
import time
|
||||
import openai
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from scipy.sparse import csr_matrix
|
||||
from typing import Mapping, List, Tuple, Any, Union, Callable
|
||||
from bertopic.representation._base import BaseRepresentation
|
||||
from bertopic.representation._utils import (
|
||||
retry_with_exponential_backoff,
|
||||
truncate_document,
|
||||
validate_truncate_document_parameters,
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_CHAT_PROMPT = """You will extract a short topic label from given documents and keywords.
|
||||
Here are two examples of topics you created before:
|
||||
|
||||
# Example 1
|
||||
Sample texts from this topic:
|
||||
- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
|
||||
- Meat, but especially beef, is the worst food in terms of emissions.
|
||||
- Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one.
|
||||
|
||||
Keywords: meat beef eat eating emissions steak food health processed chicken
|
||||
topic: Environmental impacts of eating meat
|
||||
|
||||
# Example 2
|
||||
Sample texts from this topic:
|
||||
- I have ordered the product weeks ago but it still has not arrived!
|
||||
- The website mentions that it only takes a couple of days to deliver but I still have not received mine.
|
||||
- I got a message stating that I received the monitor but that is not true!
|
||||
- It took a month longer to deliver than was advised...
|
||||
|
||||
Keywords: deliver weeks product shipping long delivery received arrived arrive week
|
||||
topic: Shipping and delivery issues
|
||||
|
||||
# Your task
|
||||
Sample texts from this topic:
|
||||
[DOCUMENTS]
|
||||
|
||||
Keywords: [KEYWORDS]
|
||||
|
||||
Based on the information above, extract a short topic label (three words at most) in the following format:
|
||||
topic: <topic_label>
|
||||
"""
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = "You are an assistant that extracts high-level topics from texts."
|
||||
|
||||
|
||||
class OpenAI(BaseRepresentation):
|
||||
r"""Using the OpenAI API to generate topic labels based
|
||||
on one of their Completion of ChatCompletion models.
|
||||
|
||||
For an overview see:
|
||||
https://platform.openai.com/docs/models
|
||||
|
||||
Arguments:
|
||||
client: A `openai.OpenAI` client
|
||||
model: Model to use within OpenAI, defaults to `"gpt-4o-mini"`.
|
||||
generator_kwargs: Kwargs passed to `openai.Completion.create`
|
||||
for fine-tuning the output.
|
||||
prompt: The prompt to be used in the model. If no prompt is given,
|
||||
`self.default_prompt_` is used instead.
|
||||
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
|
||||
to decide where the keywords and documents need to be
|
||||
inserted.
|
||||
system_prompt: The system prompt to be used in the model. If no system prompt is given,
|
||||
`self.default_system_prompt_` is used instead.
|
||||
delay_in_seconds: The delay in seconds between consecutive prompts
|
||||
in order to prevent RateLimitErrors.
|
||||
exponential_backoff: Retry requests with a random exponential backoff.
|
||||
A short sleep is used when a rate limit error is hit,
|
||||
then the requests is retried. Increase the sleep length
|
||||
if errors are hit until 10 unsuccessful requests.
|
||||
If True, overrides `delay_in_seconds`.
|
||||
nr_docs: The number of documents to pass to OpenAI if a prompt
|
||||
with the `["DOCUMENTS"]` tag is used.
|
||||
diversity: The diversity of documents to pass to OpenAI.
|
||||
Accepts values between 0 and 1. A higher
|
||||
values results in passing more diverse documents
|
||||
whereas lower values passes more similar documents.
|
||||
doc_length: The maximum length of each document. If a document is longer,
|
||||
it will be truncated. If None, the entire document is passed.
|
||||
tokenizer: The tokenizer used to calculate to split the document into segments
|
||||
used to count the length of a document.
|
||||
* If tokenizer is 'char', then the document is split up
|
||||
into characters which are counted to adhere to `doc_length`
|
||||
* If tokenizer is 'whitespace', the document is split up
|
||||
into words separated by whitespaces. These words are counted
|
||||
and truncated depending on `doc_length`
|
||||
* If tokenizer is 'vectorizer', then the internal CountVectorizer
|
||||
is used to tokenize the document. These tokens are counted
|
||||
and truncated depending on `doc_length`
|
||||
* If tokenizer is a callable, then that callable is used to tokenize
|
||||
the document. These tokens are counted and truncated depending
|
||||
on `doc_length`
|
||||
|
||||
Usage:
|
||||
|
||||
To use this, you will need to install the openai package first:
|
||||
|
||||
`pip install openai`
|
||||
|
||||
Then, get yourself an API key and use OpenAI's API as follows:
|
||||
|
||||
```python
|
||||
import openai
|
||||
from bertopic.representation import OpenAI
|
||||
from bertopic import BERTopic
|
||||
|
||||
# Create your representation model
|
||||
client = openai.OpenAI(api_key=MY_API_KEY)
|
||||
representation_model = OpenAI(client, delay_in_seconds=5)
|
||||
|
||||
# Use the representation model in BERTopic on top of the default pipeline
|
||||
topic_model = BERTopic(representation_model=representation_model)
|
||||
```
|
||||
|
||||
You can also use a custom prompt:
|
||||
|
||||
```python
|
||||
prompt = "I have the following documents: [DOCUMENTS] \nThese documents are about the following topic: '"
|
||||
representation_model = OpenAI(client, prompt=prompt, delay_in_seconds=5)
|
||||
```
|
||||
|
||||
To choose a model:
|
||||
|
||||
```python
|
||||
representation_model = OpenAI(client, model="gpt-4o-mini", delay_in_seconds=10)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client,
|
||||
model: str = "gpt-4o-mini",
|
||||
prompt: str = None,
|
||||
system_prompt: str = None,
|
||||
generator_kwargs: Mapping[str, Any] = {},
|
||||
delay_in_seconds: float = None,
|
||||
exponential_backoff: bool = False,
|
||||
nr_docs: int = 4,
|
||||
diversity: float = None,
|
||||
doc_length: int = None,
|
||||
tokenizer: Union[str, Callable] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.client = client
|
||||
self.model = model
|
||||
|
||||
if prompt is None:
|
||||
self.prompt = DEFAULT_CHAT_PROMPT
|
||||
else:
|
||||
self.prompt = prompt
|
||||
|
||||
if system_prompt is None:
|
||||
self.system_prompt = DEFAULT_SYSTEM_PROMPT
|
||||
else:
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
self.default_prompt_ = DEFAULT_CHAT_PROMPT
|
||||
self.default_system_prompt_ = DEFAULT_SYSTEM_PROMPT
|
||||
self.delay_in_seconds = delay_in_seconds
|
||||
self.exponential_backoff = exponential_backoff
|
||||
self.nr_docs = nr_docs
|
||||
self.diversity = diversity
|
||||
self.doc_length = doc_length
|
||||
self.tokenizer = tokenizer
|
||||
validate_truncate_document_parameters(self.tokenizer, self.doc_length)
|
||||
|
||||
self.prompts_ = []
|
||||
|
||||
self.generator_kwargs = generator_kwargs
|
||||
if self.generator_kwargs.get("model"):
|
||||
self.model = generator_kwargs.get("model")
|
||||
del self.generator_kwargs["model"]
|
||||
if self.generator_kwargs.get("prompt"):
|
||||
del self.generator_kwargs["prompt"]
|
||||
if not self.generator_kwargs.get("stop"):
|
||||
self.generator_kwargs["stop"] = "\n"
|
||||
|
||||
def extract_topics(
|
||||
self,
|
||||
topic_model,
|
||||
documents: pd.DataFrame,
|
||||
c_tf_idf: csr_matrix,
|
||||
topics: Mapping[str, List[Tuple[str, float]]],
|
||||
) -> Mapping[str, List[Tuple[str, float]]]:
|
||||
"""Extract topics.
|
||||
|
||||
Arguments:
|
||||
topic_model: A BERTopic model
|
||||
documents: All input documents
|
||||
c_tf_idf: The topic c-TF-IDF representation
|
||||
topics: The candidate topics as calculated with c-TF-IDF
|
||||
|
||||
Returns:
|
||||
updated_topics: Updated topic representations
|
||||
"""
|
||||
# Extract the top n representative documents per topic
|
||||
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
|
||||
c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity
|
||||
)
|
||||
|
||||
# Generate using OpenAI's Language Model
|
||||
updated_topics = {}
|
||||
for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
|
||||
truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
|
||||
prompt = self._create_prompt(truncated_docs, topic, topics)
|
||||
self.prompts_.append(prompt)
|
||||
|
||||
# Delay
|
||||
if self.delay_in_seconds:
|
||||
time.sleep(self.delay_in_seconds)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": self.system_prompt},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
kwargs = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
**self.generator_kwargs,
|
||||
}
|
||||
if self.exponential_backoff:
|
||||
response = chat_completions_with_backoff(self.client, **kwargs)
|
||||
else:
|
||||
response = self.client.chat.completions.create(**kwargs)
|
||||
|
||||
# Check whether content was actually generated
|
||||
# Addresses #1570 for potential issues with OpenAI's content filter
|
||||
# Addresses #2176 for potential issues when openAI returns a None type object
|
||||
if response and hasattr(response.choices[0].message, "content"):
|
||||
label = response.choices[0].message.content.strip().replace("topic: ", "")
|
||||
else:
|
||||
label = "No label returned"
|
||||
|
||||
updated_topics[topic] = [(label, 1)]
|
||||
|
||||
return updated_topics
|
||||
|
||||
def _create_prompt(self, docs, topic, topics):
|
||||
keywords = list(zip(*topics[topic]))[0]
|
||||
|
||||
# Use the Default Chat Prompt
|
||||
if self.prompt == DEFAULT_CHAT_PROMPT:
|
||||
prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
|
||||
prompt = self._replace_documents(prompt, docs)
|
||||
|
||||
# Use a custom prompt that leverages keywords, documents or both using
|
||||
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
|
||||
else:
|
||||
prompt = self.prompt
|
||||
if "[KEYWORDS]" in prompt:
|
||||
prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords))
|
||||
if "[DOCUMENTS]" in prompt:
|
||||
prompt = self._replace_documents(prompt, docs)
|
||||
|
||||
return prompt
|
||||
|
||||
@staticmethod
|
||||
def _replace_documents(prompt, docs):
|
||||
to_replace = ""
|
||||
for doc in docs:
|
||||
to_replace += f"- {doc}\n"
|
||||
prompt = prompt.replace("[DOCUMENTS]", to_replace)
|
||||
return prompt
|
||||
|
||||
|
||||
def chat_completions_with_backoff(client, **kwargs):
|
||||
return retry_with_exponential_backoff(
|
||||
client.chat.completions.create,
|
||||
errors=(openai.RateLimitError,),
|
||||
)(**kwargs)
|
||||
@@ -0,0 +1,161 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import spacy
|
||||
from spacy.matcher import Matcher
|
||||
from spacy.language import Language
|
||||
|
||||
from packaging import version
|
||||
from scipy.sparse import csr_matrix
|
||||
from typing import List, Mapping, Tuple, Union
|
||||
from sklearn import __version__ as sklearn_version
|
||||
from bertopic.representation._base import BaseRepresentation
|
||||
|
||||
|
||||
class PartOfSpeech(BaseRepresentation):
|
||||
"""Extract Topic Keywords based on their Part-of-Speech.
|
||||
|
||||
DEFAULT_PATTERNS = [
|
||||
[{'POS': 'ADJ'}, {'POS': 'NOUN'}],
|
||||
[{'POS': 'NOUN'}],
|
||||
[{'POS': 'ADJ'}]
|
||||
]
|
||||
|
||||
From candidate topics, as extracted with c-TF-IDF,
|
||||
find documents that contain keywords found in the
|
||||
candidate topics. These candidate documents then
|
||||
serve as the representative set of documents from
|
||||
which the Spacy model can extract a set of candidate
|
||||
keywords for each topic.
|
||||
|
||||
These candidate keywords are first judged by whether
|
||||
they fall within the DEFAULT_PATTERNS or the user-defined
|
||||
pattern. Then, the resulting keywords are sorted by
|
||||
their respective c-TF-IDF values.
|
||||
|
||||
Arguments:
|
||||
model: The Spacy model to use
|
||||
top_n_words: The top n words to extract
|
||||
pos_patterns: Patterns for Spacy to use.
|
||||
See https://spacy.io/usage/rule-based-matching
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
from bertopic.representation import PartOfSpeech
|
||||
from bertopic import BERTopic
|
||||
|
||||
# Create your representation model
|
||||
representation_model = PartOfSpeech("en_core_web_sm")
|
||||
|
||||
# Use the representation model in BERTopic on top of the default pipeline
|
||||
topic_model = BERTopic(representation_model=representation_model)
|
||||
```
|
||||
|
||||
You can define custom POS patterns to be extracted:
|
||||
|
||||
```python
|
||||
pos_patterns = [
|
||||
[{'POS': 'ADJ'}, {'POS': 'NOUN'}],
|
||||
[{'POS': 'NOUN'}], [{'POS': 'ADJ'}]
|
||||
]
|
||||
representation_model = PartOfSpeech("en_core_web_sm", pos_patterns=pos_patterns)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[str, Language] = "en_core_web_sm",
|
||||
top_n_words: int = 10,
|
||||
pos_patterns: List[str] = None,
|
||||
):
|
||||
if isinstance(model, str):
|
||||
self.model = spacy.load(model)
|
||||
elif isinstance(model, Language):
|
||||
self.model = model
|
||||
else:
|
||||
raise ValueError(
|
||||
"Make sure that the Spacy model that you"
|
||||
"pass is either a string referring to a"
|
||||
"Spacy model or a Spacy nlp object."
|
||||
)
|
||||
|
||||
self.top_n_words = top_n_words
|
||||
|
||||
if pos_patterns is None:
|
||||
self.pos_patterns = [
|
||||
[{"POS": "ADJ"}, {"POS": "NOUN"}],
|
||||
[{"POS": "NOUN"}],
|
||||
[{"POS": "ADJ"}],
|
||||
]
|
||||
else:
|
||||
self.pos_patterns = pos_patterns
|
||||
|
||||
def extract_topics(
|
||||
self,
|
||||
topic_model,
|
||||
documents: pd.DataFrame,
|
||||
c_tf_idf: csr_matrix,
|
||||
topics: Mapping[str, List[Tuple[str, float]]],
|
||||
) -> Mapping[str, List[Tuple[str, float]]]:
|
||||
"""Extract topics.
|
||||
|
||||
Arguments:
|
||||
topic_model: A BERTopic model
|
||||
documents: All input documents
|
||||
c_tf_idf: Not used
|
||||
topics: The candidate topics as calculated with c-TF-IDF
|
||||
|
||||
Returns:
|
||||
updated_topics: Updated topic representations
|
||||
"""
|
||||
matcher = Matcher(self.model.vocab)
|
||||
matcher.add("Pattern", self.pos_patterns)
|
||||
|
||||
candidate_topics = {}
|
||||
for topic, values in topics.items():
|
||||
keywords = list(zip(*values))[0]
|
||||
|
||||
# Extract candidate documents
|
||||
candidate_documents = []
|
||||
for keyword in keywords:
|
||||
selection = documents.loc[documents.Topic == topic, :]
|
||||
selection = selection.loc[selection.Document.str.contains(keyword, regex=False), "Document"]
|
||||
if len(selection) > 0:
|
||||
for document in selection[:2]:
|
||||
candidate_documents.append(document)
|
||||
candidate_documents = list(set(candidate_documents))
|
||||
|
||||
# Extract keywords
|
||||
docs_pipeline = self.model.pipe(candidate_documents)
|
||||
updated_keywords = []
|
||||
for doc in docs_pipeline:
|
||||
matches = matcher(doc)
|
||||
for _, start, end in matches:
|
||||
updated_keywords.append(doc[start:end].text)
|
||||
candidate_topics[topic] = list(set(updated_keywords))
|
||||
|
||||
# Scikit-Learn Deprecation: get_feature_names is deprecated in 1.0
|
||||
# and will be removed in 1.2. Please use get_feature_names_out instead.
|
||||
if version.parse(sklearn_version) >= version.parse("1.0.0"):
|
||||
words = list(topic_model.vectorizer_model.get_feature_names_out())
|
||||
else:
|
||||
words = list(topic_model.vectorizer_model.get_feature_names())
|
||||
|
||||
# Match updated keywords with c-TF-IDF values
|
||||
words_lookup = dict(zip(words, range(len(words))))
|
||||
updated_topics = {topic: [] for topic in topics.keys()}
|
||||
|
||||
for topic, candidate_keywords in candidate_topics.items():
|
||||
word_indices = np.sort(
|
||||
[words_lookup.get(keyword) for keyword in candidate_keywords if keyword in words_lookup]
|
||||
)
|
||||
vals = topic_model.c_tf_idf_[:, word_indices][topic + topic_model._outliers]
|
||||
indices = np.argsort(np.array(vals.todense().reshape(1, -1))[0])[-self.top_n_words :][::-1]
|
||||
vals = np.sort(np.array(vals.todense().reshape(1, -1))[0])[-self.top_n_words :][::-1]
|
||||
topic_words = [(words[word_indices[index]], val) for index, val in zip(indices, vals)]
|
||||
updated_topics[topic] = topic_words
|
||||
if len(updated_topics[topic]) < self.top_n_words:
|
||||
updated_topics[topic] += [("", 0) for _ in range(self.top_n_words - len(updated_topics[topic]))]
|
||||
|
||||
return updated_topics
|
||||
@@ -0,0 +1,188 @@
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from scipy.sparse import csr_matrix
|
||||
from transformers import pipeline, set_seed
|
||||
from transformers.pipelines.base import Pipeline
|
||||
from typing import Mapping, List, Tuple, Any, Union, Callable
|
||||
from bertopic.representation._base import BaseRepresentation
|
||||
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters
|
||||
|
||||
|
||||
DEFAULT_PROMPT = """
|
||||
I have a topic described by the following keywords: [KEYWORDS].
|
||||
The name of this topic is:
|
||||
"""
|
||||
|
||||
|
||||
class TextGeneration(BaseRepresentation):
|
||||
"""Text2Text or text generation with transformers.
|
||||
|
||||
Arguments:
|
||||
model: A transformers pipeline that should be initialized as "text-generation"
|
||||
for gpt-like models or "text2text-generation" for T5-like models.
|
||||
For example, `pipeline('text-generation', model='gpt2')`. If a string
|
||||
is passed, "text-generation" will be selected by default.
|
||||
prompt: The prompt to be used in the model. If no prompt is given,
|
||||
`self.default_prompt_` is used instead.
|
||||
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
|
||||
to decide where the keywords and documents need to be
|
||||
inserted.
|
||||
pipeline_kwargs: Kwargs that you can pass to the transformers.pipeline
|
||||
when it is called.
|
||||
random_state: A random state to be passed to `transformers.set_seed`
|
||||
nr_docs: The number of documents to pass to OpenAI if a prompt
|
||||
with the `["DOCUMENTS"]` tag is used.
|
||||
diversity: The diversity of documents to pass to OpenAI.
|
||||
Accepts values between 0 and 1. A higher
|
||||
values results in passing more diverse documents
|
||||
whereas lower values passes more similar documents.
|
||||
doc_length: The maximum length of each document. If a document is longer,
|
||||
it will be truncated. If None, the entire document is passed.
|
||||
tokenizer: The tokenizer used to calculate to split the document into segments
|
||||
used to count the length of a document.
|
||||
* If tokenizer is 'char', then the document is split up
|
||||
into characters which are counted to adhere to `doc_length`
|
||||
* If tokenizer is 'whitespace', the document is split up
|
||||
into words separated by whitespaces. These words are counted
|
||||
and truncated depending on `doc_length`
|
||||
* If tokenizer is 'vectorizer', then the internal CountVectorizer
|
||||
is used to tokenize the document. These tokens are counted
|
||||
and truncated depending on `doc_length`
|
||||
* If tokenizer is a callable, then that callable is used to tokenize
|
||||
the document. These tokens are counted and truncated depending
|
||||
on `doc_length`
|
||||
|
||||
Usage:
|
||||
|
||||
To use a gpt-like model:
|
||||
|
||||
```python
|
||||
from bertopic.representation import TextGeneration
|
||||
from bertopic import BERTopic
|
||||
|
||||
# Create your representation model
|
||||
generator = pipeline('text-generation', model='gpt2')
|
||||
representation_model = TextGeneration(generator)
|
||||
|
||||
# Use the representation model in BERTopic on top of the default pipeline
|
||||
topic_model = BERTo pic(representation_model=representation_model)
|
||||
```
|
||||
|
||||
You can use a custom prompt and decide where the keywords should
|
||||
be inserted by using the `[KEYWORDS]` or documents with thte `[DOCUMENTS]` tag:
|
||||
|
||||
```python
|
||||
from bertopic.representation import TextGeneration
|
||||
|
||||
prompt = "I have a topic described by the following keywords: [KEYWORDS]. Based on the previous keywords, what is this topic about?""
|
||||
|
||||
# Create your representation model
|
||||
generator = pipeline('text2text-generation', model='google/flan-t5-base')
|
||||
representation_model = TextGeneration(generator)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[str, pipeline],
|
||||
prompt: str = None,
|
||||
pipeline_kwargs: Mapping[str, Any] = {},
|
||||
random_state: int = 42,
|
||||
nr_docs: int = 4,
|
||||
diversity: float = None,
|
||||
doc_length: int = None,
|
||||
tokenizer: Union[str, Callable] = None,
|
||||
):
|
||||
self.random_state = random_state
|
||||
set_seed(random_state)
|
||||
if isinstance(model, str):
|
||||
self.model = pipeline("text-generation", model=model)
|
||||
elif isinstance(model, Pipeline):
|
||||
self.model = model
|
||||
else:
|
||||
raise ValueError(
|
||||
"Make sure that the HF model that you"
|
||||
"pass is either a string referring to a"
|
||||
"HF model or a `transformers.pipeline` object."
|
||||
)
|
||||
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
|
||||
self.default_prompt_ = DEFAULT_PROMPT
|
||||
self.pipeline_kwargs = pipeline_kwargs
|
||||
self.nr_docs = nr_docs
|
||||
self.diversity = diversity
|
||||
self.doc_length = doc_length
|
||||
self.tokenizer = tokenizer
|
||||
validate_truncate_document_parameters(self.tokenizer, self.doc_length)
|
||||
|
||||
self.prompts_ = []
|
||||
|
||||
def extract_topics(
|
||||
self,
|
||||
topic_model,
|
||||
documents: pd.DataFrame,
|
||||
c_tf_idf: csr_matrix,
|
||||
topics: Mapping[str, List[Tuple[str, float]]],
|
||||
) -> Mapping[str, List[Tuple[str, float]]]:
|
||||
"""Extract topic representations and return a single label.
|
||||
|
||||
Arguments:
|
||||
topic_model: A BERTopic model
|
||||
documents: Not used
|
||||
c_tf_idf: Not used
|
||||
topics: The candidate topics as calculated with c-TF-IDF
|
||||
|
||||
Returns:
|
||||
updated_topics: Updated topic representations
|
||||
"""
|
||||
# Extract the top 4 representative documents per topic
|
||||
if self.prompt != DEFAULT_PROMPT and "[DOCUMENTS]" in self.prompt:
|
||||
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
|
||||
c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity
|
||||
)
|
||||
else:
|
||||
repr_docs_mappings = {topic: None for topic in topics.keys()}
|
||||
|
||||
updated_topics = {}
|
||||
for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
|
||||
# Prepare prompt
|
||||
truncated_docs = (
|
||||
[truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
|
||||
if docs is not None
|
||||
else docs
|
||||
)
|
||||
prompt = self._create_prompt(truncated_docs, topic, topics)
|
||||
self.prompts_.append(prompt)
|
||||
|
||||
# Extract result from generator and use that as label
|
||||
topic_description = self.model(prompt, **self.pipeline_kwargs)
|
||||
topic_description = [
|
||||
(description["generated_text"].replace(prompt, ""), 1) for description in topic_description
|
||||
]
|
||||
|
||||
if len(topic_description) < 10:
|
||||
topic_description += [("", 0) for _ in range(10 - len(topic_description))]
|
||||
|
||||
updated_topics[topic] = topic_description
|
||||
|
||||
return updated_topics
|
||||
|
||||
def _create_prompt(self, docs, topic, topics):
|
||||
keywords = ", ".join(list(zip(*topics[topic]))[0])
|
||||
|
||||
# Use the default prompt and replace keywords
|
||||
if self.prompt == DEFAULT_PROMPT:
|
||||
prompt = self.prompt.replace("[KEYWORDS]", keywords)
|
||||
|
||||
# Use a prompt that leverages either keywords or documents in
|
||||
# a custom location
|
||||
else:
|
||||
prompt = self.prompt
|
||||
if "[KEYWORDS]" in prompt:
|
||||
prompt = prompt.replace("[KEYWORDS]", keywords)
|
||||
if "[DOCUMENTS]" in prompt:
|
||||
to_replace = ""
|
||||
for doc in docs:
|
||||
to_replace += f"- {doc}\n"
|
||||
prompt = prompt.replace("[DOCUMENTS]", to_replace)
|
||||
|
||||
return prompt
|
||||
@@ -0,0 +1,113 @@
|
||||
import random
|
||||
import time
|
||||
from typing import Union
|
||||
|
||||
|
||||
def truncate_document(topic_model, doc_length: Union[int, None], tokenizer: Union[str, callable], document: str) -> str:
|
||||
"""Truncate a document to a certain length.
|
||||
|
||||
If you want to add a custom tokenizer, then it will need to have a `decode` and
|
||||
`encode` method. An example would be the following custom tokenizer:
|
||||
|
||||
```python
|
||||
class Tokenizer:
|
||||
'A custom tokenizer that splits on commas'
|
||||
def encode(self, doc):
|
||||
return doc.split(",")
|
||||
|
||||
def decode(self, doc_chunks):
|
||||
return ",".join(doc_chunks)
|
||||
```
|
||||
|
||||
You can use this tokenizer by passing it to the `tokenizer` parameter.
|
||||
|
||||
Arguments:
|
||||
topic_model: A BERTopic model
|
||||
doc_length: The maximum length of each document. If a document is longer,
|
||||
it will be truncated. If None, the entire document is passed.
|
||||
tokenizer: The tokenizer used to calculate to split the document into segments
|
||||
used to count the length of a document.
|
||||
* If tokenizer is 'char', then the document is split up
|
||||
into characters which are counted to adhere to `doc_length`
|
||||
* If tokenizer is 'whitespace', the document is split up
|
||||
into words separated by whitespaces. These words are counted
|
||||
and truncated depending on `doc_length`
|
||||
* If tokenizer is 'vectorizer', then the internal CountVectorizer
|
||||
is used to tokenize the document. These tokens are counted
|
||||
and truncated depending on `doc_length`. They are decoded with
|
||||
whitespaces.
|
||||
* If tokenizer is a callable, then that callable is used to tokenize
|
||||
the document. These tokens are counted and truncated depending
|
||||
on `doc_length`
|
||||
document: A single document
|
||||
|
||||
Returns:
|
||||
truncated_document: A truncated document
|
||||
"""
|
||||
if doc_length is not None:
|
||||
if tokenizer == "char":
|
||||
truncated_document = document[:doc_length]
|
||||
elif tokenizer == "whitespace":
|
||||
truncated_document = " ".join(document.split()[:doc_length])
|
||||
elif tokenizer == "vectorizer":
|
||||
tokenizer = topic_model.vectorizer_model.build_tokenizer()
|
||||
truncated_document = " ".join(tokenizer(document)[:doc_length])
|
||||
elif hasattr(tokenizer, "encode") and hasattr(tokenizer, "decode"):
|
||||
encoded_document = tokenizer.encode(document)
|
||||
truncated_document = tokenizer.decode(encoded_document[:doc_length])
|
||||
return truncated_document
|
||||
return document
|
||||
|
||||
|
||||
def validate_truncate_document_parameters(tokenizer, doc_length) -> Union[None, ValueError]:
|
||||
"""Validates parameters that are used in the function `truncate_document`."""
|
||||
if tokenizer is None and doc_length is not None:
|
||||
raise ValueError(
|
||||
"Please select from one of the valid options for the `tokenizer` parameter: \n"
|
||||
"{'char', 'whitespace', 'vectorizer'} \n"
|
||||
"If `tokenizer` is of type callable ensure it has methods to encode and decode a document \n"
|
||||
)
|
||||
elif tokenizer is not None and doc_length is None:
|
||||
raise ValueError("If `tokenizer` is provided, `doc_length` of type int must be provided as well.")
|
||||
|
||||
|
||||
def retry_with_exponential_backoff(
|
||||
func,
|
||||
initial_delay: float = 1,
|
||||
exponential_base: float = 2,
|
||||
jitter: bool = True,
|
||||
max_retries: int = 10,
|
||||
errors: tuple = None,
|
||||
):
|
||||
"""Retry a function with exponential backoff."""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
# Initialize variables
|
||||
num_retries = 0
|
||||
delay = initial_delay
|
||||
|
||||
# Loop until a successful response or max_retries is hit or an exception is raised
|
||||
while True:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# Retry on specific errors
|
||||
except errors:
|
||||
# Increment retries
|
||||
num_retries += 1
|
||||
|
||||
# Check if max retries has been reached
|
||||
if num_retries > max_retries:
|
||||
raise Exception(f"Maximum number of retries ({max_retries}) exceeded.")
|
||||
|
||||
# Increment the delay
|
||||
delay *= exponential_base * (1 + jitter * random.random())
|
||||
|
||||
# Sleep for the delay
|
||||
time.sleep(delay)
|
||||
|
||||
# Raise exceptions for any errors not specified
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return wrapper
|
||||
@@ -0,0 +1,274 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from scipy.sparse import csr_matrix
|
||||
from typing import Mapping, List, Tuple, Union
|
||||
from transformers.pipelines import Pipeline, pipeline
|
||||
|
||||
from bertopic.representation._mmr import mmr
|
||||
from bertopic.representation._base import BaseRepresentation
|
||||
|
||||
|
||||
class VisualRepresentation(BaseRepresentation):
|
||||
"""From a collection of representative documents, extract
|
||||
images to represent topics. These topics are represented by a
|
||||
collage of images.
|
||||
|
||||
Arguments:
|
||||
nr_repr_images: Number of representative images to extract
|
||||
nr_samples: The number of candidate documents to extract per cluster.
|
||||
image_height: The height of the resulting collage
|
||||
image_square: Whether to resize each image in the collage
|
||||
to a square. This can be visually more appealing
|
||||
if all input images are all almost squares.
|
||||
image_to_text_model: The model to caption images.
|
||||
batch_size: The number of images to pass to the
|
||||
`image_to_text_model`.
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
from bertopic.representation import VisualRepresentation
|
||||
from bertopic import BERTopic
|
||||
|
||||
# The visual representation is typically not a core representation
|
||||
# and is advised to pass to BERTopic as an additional aspect.
|
||||
# Aspects can be labeled with dictionaries as shown below:
|
||||
representation_model = {
|
||||
"Visual_Aspect": VisualRepresentation()
|
||||
}
|
||||
|
||||
# Use the representation model in BERTopic as a separate aspect
|
||||
topic_model = BERTopic(representation_model=representation_model)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nr_repr_images: int = 9,
|
||||
nr_samples: int = 500,
|
||||
image_height: Tuple[int, int] = 600,
|
||||
image_squares: bool = False,
|
||||
image_to_text_model: Union[str, Pipeline] = None,
|
||||
batch_size: int = 32,
|
||||
):
|
||||
self.nr_repr_images = nr_repr_images
|
||||
self.nr_samples = nr_samples
|
||||
self.image_height = image_height
|
||||
self.image_squares = image_squares
|
||||
|
||||
# Text-to-image model
|
||||
if isinstance(image_to_text_model, Pipeline):
|
||||
self.image_to_text_model = image_to_text_model
|
||||
elif isinstance(image_to_text_model, str):
|
||||
self.image_to_text_model = pipeline("image-to-text", model=image_to_text_model)
|
||||
elif image_to_text_model is None:
|
||||
self.image_to_text_model = None
|
||||
else:
|
||||
raise ValueError(
|
||||
"Please select a correct transformers pipeline. For example:"
|
||||
"pipeline('image-to-text', model='nlpconnect/vit-gpt2-image-captioning')"
|
||||
)
|
||||
self.batch_size = batch_size
|
||||
|
||||
def extract_topics(
|
||||
self,
|
||||
topic_model,
|
||||
documents: pd.DataFrame,
|
||||
c_tf_idf: csr_matrix,
|
||||
topics: Mapping[str, List[Tuple[str, float]]],
|
||||
) -> Mapping[str, List[Tuple[str, float]]]:
|
||||
"""Extract topics.
|
||||
|
||||
Arguments:
|
||||
topic_model: A BERTopic model
|
||||
documents: All input documents
|
||||
c_tf_idf: The topic c-TF-IDF representation
|
||||
topics: The candidate topics as calculated with c-TF-IDF
|
||||
|
||||
Returns:
|
||||
representative_images: Representative images per topic
|
||||
"""
|
||||
# Extract image ids of most representative documents
|
||||
images = documents["Image"].values.tolist()
|
||||
(_, _, _, repr_docs_ids) = topic_model._extract_representative_docs(
|
||||
c_tf_idf,
|
||||
documents,
|
||||
topics,
|
||||
nr_samples=self.nr_samples,
|
||||
nr_repr_docs=self.nr_repr_images,
|
||||
)
|
||||
unique_topics = sorted(list(topics.keys()))
|
||||
|
||||
# Combine representative images into a single representation
|
||||
representative_images = {}
|
||||
for topic in tqdm(unique_topics):
|
||||
# Get and order represetnative images
|
||||
sliced_examplars = repr_docs_ids[topic + topic_model._outliers]
|
||||
sliced_examplars = [sliced_examplars[i : i + 3] for i in range(0, len(sliced_examplars), 3)]
|
||||
images_to_combine = [
|
||||
[
|
||||
Image.open(images[index]) if isinstance(images[index], str) else images[index]
|
||||
for index in sub_indices
|
||||
]
|
||||
for sub_indices in sliced_examplars
|
||||
]
|
||||
|
||||
# Concatenate representative images
|
||||
representative_image = get_concat_tile_resize(images_to_combine, self.image_height, self.image_squares)
|
||||
representative_images[topic] = representative_image
|
||||
|
||||
# Make sure to properly close images
|
||||
if isinstance(images[0], str):
|
||||
for image_list in images_to_combine:
|
||||
for image in image_list:
|
||||
image.close()
|
||||
|
||||
return representative_images
|
||||
|
||||
def _convert_image_to_text(self, images: List[str], verbose: bool = False) -> List[str]:
|
||||
"""Convert a list of images to captions.
|
||||
|
||||
Arguments:
|
||||
images: A list of images or words to be converted to text.
|
||||
verbose: Controls the verbosity of the process
|
||||
|
||||
Returns:
|
||||
List of captions
|
||||
"""
|
||||
# Batch-wise image conversion
|
||||
if self.batch_size is not None:
|
||||
documents = []
|
||||
for batch in tqdm(self._chunks(images), disable=not verbose):
|
||||
outputs = self.image_to_text_model(batch)
|
||||
captions = [output[0]["generated_text"] for output in outputs]
|
||||
documents.extend(captions)
|
||||
|
||||
# Convert images to text
|
||||
else:
|
||||
outputs = self.image_to_text_model(images)
|
||||
documents = [output[0]["generated_text"] for output in outputs]
|
||||
|
||||
return documents
|
||||
|
||||
def image_to_text(self, documents: pd.DataFrame, embeddings: np.ndarray) -> pd.DataFrame:
|
||||
"""Convert images to text."""
|
||||
# Create image topic embeddings
|
||||
topics = documents.Topic.values.tolist()
|
||||
images = documents.Image.values.tolist()
|
||||
df = pd.DataFrame(np.hstack([np.array(topics).reshape(-1, 1), embeddings]))
|
||||
image_topic_embeddings = df.groupby(0).mean().values
|
||||
|
||||
# Extract image centroids
|
||||
image_centroids = {}
|
||||
unique_topics = sorted(list(set(topics)))
|
||||
for topic, topic_embedding in zip(unique_topics, image_topic_embeddings):
|
||||
indices = np.array([index for index, t in enumerate(topics) if t == topic])
|
||||
top_n = min([self.nr_repr_images, len(indices)])
|
||||
indices = mmr(
|
||||
topic_embedding.reshape(1, -1),
|
||||
embeddings[indices],
|
||||
indices,
|
||||
top_n=top_n,
|
||||
diversity=0.1,
|
||||
)
|
||||
image_centroids[topic] = indices
|
||||
|
||||
# Extract documents
|
||||
documents = pd.DataFrame(columns=["Document", "ID", "Topic", "Image"])
|
||||
current_id = 0
|
||||
for topic, image_ids in tqdm(image_centroids.items()):
|
||||
selected_images = [
|
||||
Image.open(images[index]) if isinstance(images[index], str) else images[index] for index in image_ids
|
||||
]
|
||||
text = self._convert_image_to_text(selected_images)
|
||||
|
||||
for doc, image_id in zip(text, image_ids):
|
||||
documents.loc[len(documents), :] = [
|
||||
doc,
|
||||
current_id,
|
||||
topic,
|
||||
images[image_id],
|
||||
]
|
||||
current_id += 1
|
||||
|
||||
# Properly close images
|
||||
if isinstance(images[image_ids[0]], str):
|
||||
for image in selected_images:
|
||||
image.close()
|
||||
|
||||
return documents
|
||||
|
||||
def _chunks(self, images):
|
||||
for i in range(0, len(images), self.batch_size):
|
||||
yield images[i : i + self.batch_size]
|
||||
|
||||
|
||||
def get_concat_h_multi_resize(im_list):
|
||||
"""Code adapted from: https://note.nkmk.me/en/python-pillow-concat-images/."""
|
||||
min_height = min(im.height for im in im_list)
|
||||
min_height = max(im.height for im in im_list)
|
||||
im_list_resize = []
|
||||
for im in im_list:
|
||||
im.resize((int(im.width * min_height / im.height), min_height), resample=0)
|
||||
im_list_resize.append(im)
|
||||
|
||||
total_width = sum(im.width for im in im_list_resize)
|
||||
dst = Image.new("RGB", (total_width, min_height), (255, 255, 255))
|
||||
pos_x = 0
|
||||
for im in im_list_resize:
|
||||
dst.paste(im, (pos_x, 0))
|
||||
pos_x += im.width
|
||||
return dst
|
||||
|
||||
|
||||
def get_concat_v_multi_resize(im_list):
|
||||
"""Code adapted from: https://note.nkmk.me/en/python-pillow-concat-images/."""
|
||||
min_width = min(im.width for im in im_list)
|
||||
min_width = max(im.width for im in im_list)
|
||||
im_list_resize = [im.resize((min_width, int(im.height * min_width / im.width)), resample=0) for im in im_list]
|
||||
total_height = sum(im.height for im in im_list_resize)
|
||||
dst = Image.new("RGB", (min_width, total_height), (255, 255, 255))
|
||||
pos_y = 0
|
||||
for im in im_list_resize:
|
||||
dst.paste(im, (0, pos_y))
|
||||
pos_y += im.height
|
||||
return dst
|
||||
|
||||
|
||||
def get_concat_tile_resize(im_list_2d, image_height=600, image_squares=False):
|
||||
"""Code adapted from: https://note.nkmk.me/en/python-pillow-concat-images/."""
|
||||
images = [[image.copy() for image in images] for images in im_list_2d]
|
||||
|
||||
# Create
|
||||
if image_squares:
|
||||
width = int(image_height / 3)
|
||||
height = int(image_height / 3)
|
||||
images = [[image.resize((width, height)) for image in images] for images in im_list_2d]
|
||||
|
||||
# Resize images based on minimum size
|
||||
else:
|
||||
min_width = min([min([img.width for img in imgs]) for imgs in im_list_2d])
|
||||
min_height = min([min([img.height for img in imgs]) for imgs in im_list_2d])
|
||||
for i, imgs in enumerate(images):
|
||||
for j, img in enumerate(imgs):
|
||||
if img.height > img.width:
|
||||
images[i][j] = img.resize(
|
||||
(int(img.width * min_height / img.height), min_height),
|
||||
resample=0,
|
||||
)
|
||||
elif img.width > img.height:
|
||||
images[i][j] = img.resize((min_width, int(img.height * min_width / img.width)), resample=0)
|
||||
else:
|
||||
images[i][j] = img.resize((min_width, min_width))
|
||||
|
||||
# Resize grid image
|
||||
images = [get_concat_h_multi_resize(im_list_h) for im_list_h in images]
|
||||
img = get_concat_v_multi_resize(images)
|
||||
height_percentage = image_height / float(img.size[1])
|
||||
adjusted_width = int((float(img.size[0]) * float(height_percentage)))
|
||||
img = img.resize((adjusted_width, image_height), Image.Resampling.LANCZOS)
|
||||
|
||||
return img
|
||||
@@ -0,0 +1,104 @@
|
||||
import pandas as pd
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines.base import Pipeline
|
||||
from scipy.sparse import csr_matrix
|
||||
from typing import Mapping, List, Tuple, Any
|
||||
from bertopic.representation._base import BaseRepresentation
|
||||
|
||||
|
||||
class ZeroShotClassification(BaseRepresentation):
|
||||
"""Zero-shot Classification on topic keywords with candidate labels.
|
||||
|
||||
Arguments:
|
||||
candidate_topics: A list of labels to assign to the topics if they
|
||||
exceed `min_prob`
|
||||
model: A transformers pipeline that should be initialized as
|
||||
"zero-shot-classification". For example,
|
||||
`pipeline("zero-shot-classification", model="facebook/bart-large-mnli")`
|
||||
pipeline_kwargs: Kwargs that you can pass to the transformers.pipeline
|
||||
when it is called. NOTE: Use `{"multi_label": True}`
|
||||
to extract multiple labels for each topic.
|
||||
min_prob: The minimum probability to assign a candidate label to a topic
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
from bertopic.representation import ZeroShotClassification
|
||||
from bertopic import BERTopic
|
||||
|
||||
# Create your representation model
|
||||
candidate_topics = ["space and nasa", "bicycles", "sports"]
|
||||
representation_model = ZeroShotClassification(candidate_topics, model="facebook/bart-large-mnli")
|
||||
|
||||
# Use the representation model in BERTopic on top of the default pipeline
|
||||
topic_model = BERTopic(representation_model=representation_model)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
candidate_topics: List[str],
|
||||
model: str = "facebook/bart-large-mnli",
|
||||
pipeline_kwargs: Mapping[str, Any] = {},
|
||||
min_prob: float = 0.8,
|
||||
):
|
||||
self.candidate_topics = candidate_topics
|
||||
if isinstance(model, str):
|
||||
self.model = pipeline("zero-shot-classification", model=model)
|
||||
elif isinstance(model, Pipeline):
|
||||
self.model = model
|
||||
else:
|
||||
raise ValueError(
|
||||
"Make sure that the HF model that you"
|
||||
"pass is either a string referring to a"
|
||||
"HF model or a `transformers.pipeline` object."
|
||||
)
|
||||
self.pipeline_kwargs = pipeline_kwargs
|
||||
self.min_prob = min_prob
|
||||
|
||||
def extract_topics(
|
||||
self,
|
||||
topic_model,
|
||||
documents: pd.DataFrame,
|
||||
c_tf_idf: csr_matrix,
|
||||
topics: Mapping[str, List[Tuple[str, float]]],
|
||||
) -> Mapping[str, List[Tuple[str, float]]]:
|
||||
"""Extract topics.
|
||||
|
||||
Arguments:
|
||||
topic_model: Not used
|
||||
documents: Not used
|
||||
c_tf_idf: Not used
|
||||
topics: The candidate topics as calculated with c-TF-IDF
|
||||
|
||||
Returns:
|
||||
updated_topics: Updated topic representations
|
||||
"""
|
||||
# Classify topics
|
||||
topic_descriptions = [" ".join(list(zip(*topics[topic]))[0]) for topic in topics.keys()]
|
||||
classifications = self.model(topic_descriptions, self.candidate_topics, **self.pipeline_kwargs)
|
||||
|
||||
# Extract labels
|
||||
updated_topics = {}
|
||||
for topic, classification in zip(topics.keys(), classifications):
|
||||
topic_description = topics[topic]
|
||||
|
||||
# Multi-label assignment
|
||||
if self.pipeline_kwargs.get("multi_label"):
|
||||
topic_description = []
|
||||
for label, score in zip(classification["labels"], classification["scores"]):
|
||||
if score > self.min_prob:
|
||||
topic_description.append((label, score))
|
||||
|
||||
# Single label assignment
|
||||
elif classification["scores"][0] > self.min_prob:
|
||||
topic_description = [(classification["labels"][0], classification["scores"][0])]
|
||||
|
||||
# Make sure that 10 items are returned
|
||||
if len(topic_description) == 0:
|
||||
topic_description = topics[topic]
|
||||
elif len(topic_description) < 10:
|
||||
topic_description += [("", 0) for _ in range(10 - len(topic_description))]
|
||||
updated_topics[topic] = topic_description
|
||||
|
||||
return updated_topics
|
||||
Reference in New Issue
Block a user