Add BERTopic.
This commit is contained in:
@@ -0,0 +1,100 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
try:
|
||||
from pandas.io.formats.style import Styler # noqa: F401
|
||||
|
||||
HAS_JINJA = True
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
HAS_JINJA = False
|
||||
|
||||
|
||||
def visualize_approximate_distribution(
|
||||
topic_model,
|
||||
document: str,
|
||||
topic_token_distribution: np.ndarray,
|
||||
normalize: bool = False,
|
||||
):
|
||||
"""Visualize the topic distribution calculated by `.approximate_topic_distribution`
|
||||
on a token level. Thereby indicating the extend to which a certain word or phrases belong
|
||||
to a specific topic. The assumption here is that a single word can belong to multiple
|
||||
similar topics and as such give information about the broader set of topics within
|
||||
a single document.
|
||||
|
||||
Note:
|
||||
This function will return a stylized pandas dataframe if Jinja2 is installed. If not,
|
||||
it will only return a pandas dataframe without color highlighting. To install jinja:
|
||||
|
||||
`pip install jinja2`
|
||||
|
||||
Arguments:
|
||||
topic_model: A fitted BERTopic instance.
|
||||
document: The document for which you want to visualize
|
||||
the approximated topic distribution.
|
||||
topic_token_distribution: The topic-token distribution of the document as
|
||||
extracted by `.approximate_topic_distribution`
|
||||
normalize: Whether to normalize, between 0 and 1 (summing to 1), the
|
||||
topic distribution values.
|
||||
|
||||
Returns:
|
||||
df: A stylized dataframe indicating the best fitting topics
|
||||
for each token.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Calculate the topic distributions on a token level
|
||||
# Note that we need to have `calculate_token_level=True`
|
||||
topic_distr, topic_token_distr = topic_model.approximate_distribution(
|
||||
docs, calculate_token_level=True
|
||||
)
|
||||
|
||||
# Visualize the approximated topic distributions
|
||||
df = topic_model.visualize_approximate_distribution(docs[0], topic_token_distr[0])
|
||||
df
|
||||
```
|
||||
|
||||
To revert this stylized dataframe back to a regular dataframe,
|
||||
you can run the following:
|
||||
|
||||
```python
|
||||
df.data.columns = [column.strip() for column in df.data.columns]
|
||||
df = df.data
|
||||
```
|
||||
"""
|
||||
# Tokenize document
|
||||
analyzer = topic_model.vectorizer_model.build_tokenizer()
|
||||
tokens = analyzer(document)
|
||||
|
||||
if len(tokens) == 0:
|
||||
raise ValueError("Make sure that your document contains at least 1 token.")
|
||||
|
||||
# Prepare dataframe with results
|
||||
if normalize:
|
||||
df = pd.DataFrame(topic_token_distribution / topic_token_distribution.sum()).T
|
||||
else:
|
||||
df = pd.DataFrame(topic_token_distribution).T
|
||||
|
||||
df.columns = [f"{token}_{i}" for i, token in enumerate(tokens)]
|
||||
df.columns = [f"{token}{' ' * i}" for i, token in enumerate(tokens)]
|
||||
df.index = list(topic_model.topic_labels_.values())[topic_model._outliers :]
|
||||
df = df.loc[(df.sum(axis=1) != 0), :]
|
||||
|
||||
# Style the resulting dataframe
|
||||
def text_color(val):
|
||||
color = "white" if val == 0 else "black"
|
||||
return "color: %s" % color
|
||||
|
||||
def highligh_color(data, color="white"):
|
||||
attr = "background-color: {}".format(color)
|
||||
return pd.DataFrame(np.where(data == 0, attr, ""), index=data.index, columns=data.columns)
|
||||
|
||||
if len(df) == 0:
|
||||
return df
|
||||
elif HAS_JINJA:
|
||||
df = (
|
||||
df.style.format("{:.3f}")
|
||||
.background_gradient(cmap="Blues", axis=None)
|
||||
.applymap(lambda x: text_color(x))
|
||||
.apply(highligh_color, axis=None)
|
||||
)
|
||||
return df
|
||||
Reference in New Issue
Block a user