Add BERTopic.
This commit is contained in:
@@ -0,0 +1,330 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Callable, List, Union
|
||||
from scipy.sparse import csr_matrix
|
||||
from scipy.cluster import hierarchy as sch
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
from bertopic._utils import select_topic_representation
|
||||
|
||||
import plotly.graph_objects as go
|
||||
import plotly.figure_factory as ff
|
||||
|
||||
from bertopic._utils import validate_distance_matrix
|
||||
|
||||
|
||||
def visualize_hierarchy(
|
||||
topic_model,
|
||||
orientation: str = "left",
|
||||
topics: List[int] = None,
|
||||
top_n_topics: int = None,
|
||||
use_ctfidf: bool = True,
|
||||
custom_labels: Union[bool, str] = False,
|
||||
title: str = "<b>Hierarchical Clustering</b>",
|
||||
width: int = 1000,
|
||||
height: int = 600,
|
||||
hierarchical_topics: pd.DataFrame = None,
|
||||
linkage_function: Callable[[csr_matrix], np.ndarray] = None,
|
||||
distance_function: Callable[[csr_matrix], csr_matrix] = None,
|
||||
color_threshold: int = 1,
|
||||
) -> go.Figure:
|
||||
"""Visualize a hierarchical structure of the topics.
|
||||
|
||||
A ward linkage function is used to perform the
|
||||
hierarchical clustering based on the cosine distance
|
||||
matrix between topic embeddings (either c-TF-IDF or the embeddings from the embedding model).
|
||||
|
||||
Arguments:
|
||||
topic_model: A fitted BERTopic instance.
|
||||
orientation: The orientation of the figure.
|
||||
Either 'left' or 'bottom'
|
||||
topics: A selection of topics to visualize
|
||||
top_n_topics: Only select the top n most frequent topics
|
||||
use_ctfidf: Whether to calculate distances between topics based on c-TF-IDF embeddings. If False, the embeddings
|
||||
from the embedding model are used.
|
||||
custom_labels: If bool, whether to use custom topic labels that were defined using
|
||||
`topic_model.set_topic_labels`.
|
||||
If `str`, it uses labels from other aspects, e.g., "Aspect1".
|
||||
NOTE: Custom labels are only generated for the original
|
||||
un-merged topics.
|
||||
title: Title of the plot.
|
||||
width: The width of the figure. Only works if orientation is set to 'left'
|
||||
height: The height of the figure. Only works if orientation is set to 'bottom'
|
||||
hierarchical_topics: A dataframe that contains a hierarchy of topics
|
||||
represented by their parents and their children.
|
||||
NOTE: The hierarchical topic names are only visualized
|
||||
if both `topics` and `top_n_topics` are not set.
|
||||
linkage_function: The linkage function to use. Default is:
|
||||
`lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
|
||||
NOTE: Make sure to use the same `linkage_function` as used
|
||||
in `topic_model.hierarchical_topics`.
|
||||
distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
|
||||
`lambda x: 1 - cosine_similarity(x)`.
|
||||
You can pass any function that returns either a square matrix of
|
||||
shape (n_samples, n_samples) with zeros on the diagonal and
|
||||
non-negative values or condensed distance matrix of shape
|
||||
(n_samples * (n_samples - 1) / 2,) containing the upper
|
||||
triangular of the distance matrix.
|
||||
NOTE: Make sure to use the same `distance_function` as used
|
||||
in `topic_model.hierarchical_topics`.
|
||||
color_threshold: Value at which the separation of clusters will be made which
|
||||
will result in different colors for different clusters.
|
||||
A higher value will typically lead in less colored clusters.
|
||||
|
||||
Returns:
|
||||
fig: A plotly figure
|
||||
|
||||
Examples:
|
||||
To visualize the hierarchical structure of
|
||||
topics simply run:
|
||||
|
||||
```python
|
||||
topic_model.visualize_hierarchy()
|
||||
```
|
||||
|
||||
If you also want the labels visualized of hierarchical topics,
|
||||
run the following:
|
||||
|
||||
```python
|
||||
# Extract hierarchical topics and their representations
|
||||
hierarchical_topics = topic_model.hierarchical_topics(docs)
|
||||
|
||||
# Visualize these representations
|
||||
topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics)
|
||||
```
|
||||
|
||||
If you want to save the resulting figure:
|
||||
|
||||
```python
|
||||
fig = topic_model.visualize_hierarchy()
|
||||
fig.write_html("path/to/file.html")
|
||||
```
|
||||
<iframe src="../../getting_started/visualization/hierarchy.html"
|
||||
style="width:1000px; height: 680px; border: 0px;""></iframe>
|
||||
"""
|
||||
if distance_function is None:
|
||||
distance_function = lambda x: 1 - cosine_similarity(x)
|
||||
|
||||
if linkage_function is None:
|
||||
linkage_function = lambda x: sch.linkage(x, "ward", optimal_ordering=True)
|
||||
|
||||
# Select topics based on top_n and topics args
|
||||
freq_df = topic_model.get_topic_freq()
|
||||
freq_df = freq_df.loc[freq_df.Topic != -1, :]
|
||||
if topics is not None:
|
||||
topics = list(topics)
|
||||
elif top_n_topics is not None:
|
||||
topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
|
||||
else:
|
||||
topics = sorted(freq_df.Topic.to_list())
|
||||
|
||||
# Select embeddings
|
||||
all_topics = sorted(list(topic_model.get_topics().keys()))
|
||||
indices = np.array([all_topics.index(topic) for topic in topics])
|
||||
|
||||
# Select topic embeddings
|
||||
embeddings = select_topic_representation(topic_model.c_tf_idf_, topic_model.topic_embeddings_, use_ctfidf)[0][
|
||||
indices
|
||||
]
|
||||
|
||||
# Annotations
|
||||
if hierarchical_topics is not None and len(topics) == len(freq_df.Topic.to_list()):
|
||||
annotations = _get_annotations(
|
||||
topic_model=topic_model,
|
||||
hierarchical_topics=hierarchical_topics,
|
||||
embeddings=embeddings,
|
||||
distance_function=distance_function,
|
||||
linkage_function=linkage_function,
|
||||
orientation=orientation,
|
||||
custom_labels=custom_labels,
|
||||
)
|
||||
else:
|
||||
annotations = None
|
||||
|
||||
# wrap distance function to validate input and return a condensed distance matrix
|
||||
distance_function_viz = lambda x: validate_distance_matrix(distance_function(x), embeddings.shape[0])
|
||||
# Create dendogram
|
||||
fig = ff.create_dendrogram(
|
||||
embeddings,
|
||||
orientation=orientation,
|
||||
distfun=distance_function_viz,
|
||||
linkagefun=linkage_function,
|
||||
hovertext=annotations,
|
||||
color_threshold=color_threshold,
|
||||
)
|
||||
|
||||
# Create nicer labels
|
||||
axis = "yaxis" if orientation == "left" else "xaxis"
|
||||
if isinstance(custom_labels, str):
|
||||
new_labels = [
|
||||
[[str(x), None]] + topic_model.topic_aspects_[custom_labels][x] for x in fig.layout[axis]["ticktext"]
|
||||
]
|
||||
new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
|
||||
new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
|
||||
elif topic_model.custom_labels_ is not None and custom_labels:
|
||||
new_labels = [
|
||||
topic_model.custom_labels_[topics[int(x)] + topic_model._outliers] for x in fig.layout[axis]["ticktext"]
|
||||
]
|
||||
else:
|
||||
new_labels = [
|
||||
[[str(topics[int(x)]), None]] + topic_model.get_topic(topics[int(x)]) for x in fig.layout[axis]["ticktext"]
|
||||
]
|
||||
new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
|
||||
new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
|
||||
|
||||
# Stylize layout
|
||||
fig.update_layout(
|
||||
plot_bgcolor="#ECEFF1",
|
||||
template="plotly_white",
|
||||
title={
|
||||
"text": f"{title}",
|
||||
"x": 0.5,
|
||||
"xanchor": "center",
|
||||
"yanchor": "top",
|
||||
"font": dict(size=22, color="Black"),
|
||||
},
|
||||
hoverlabel=dict(bgcolor="white", font_size=16, font_family="Rockwell"),
|
||||
)
|
||||
|
||||
# Stylize orientation
|
||||
if orientation == "left":
|
||||
fig.update_layout(
|
||||
height=200 + (15 * len(topics)),
|
||||
width=width,
|
||||
yaxis=dict(tickmode="array", ticktext=new_labels),
|
||||
)
|
||||
|
||||
# Fix empty space on the bottom of the graph
|
||||
y_max = max([trace["y"].max() + 5 for trace in fig["data"]])
|
||||
y_min = min([trace["y"].min() - 5 for trace in fig["data"]])
|
||||
fig.update_layout(yaxis=dict(range=[y_min, y_max]))
|
||||
|
||||
else:
|
||||
fig.update_layout(
|
||||
width=200 + (15 * len(topics)),
|
||||
height=height,
|
||||
xaxis=dict(tickmode="array", ticktext=new_labels),
|
||||
)
|
||||
|
||||
if hierarchical_topics is not None:
|
||||
for index in [0, 3]:
|
||||
axis = "x" if orientation == "left" else "y"
|
||||
xs = [data["x"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)]
|
||||
ys = [data["y"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)]
|
||||
hovertext = [data["text"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)]
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=xs,
|
||||
y=ys,
|
||||
marker_color="black",
|
||||
hovertext=hovertext,
|
||||
hoverinfo="text",
|
||||
mode="markers",
|
||||
showlegend=False,
|
||||
)
|
||||
)
|
||||
return fig
|
||||
|
||||
|
||||
def _get_annotations(
|
||||
topic_model,
|
||||
hierarchical_topics: pd.DataFrame,
|
||||
embeddings: csr_matrix,
|
||||
linkage_function: Callable[[csr_matrix], np.ndarray],
|
||||
distance_function: Callable[[csr_matrix], csr_matrix],
|
||||
orientation: str,
|
||||
custom_labels: bool = False,
|
||||
) -> List[List[str]]:
|
||||
"""Get annotations by replicating linkage function calculation in scipy.
|
||||
|
||||
Arguments:
|
||||
topic_model: A fitted BERTopic instance.
|
||||
hierarchical_topics: A dataframe that contains a hierarchy of topics
|
||||
represented by their parents and their children.
|
||||
NOTE: The hierarchical topic names are only visualized
|
||||
if both `topics` and `top_n_topics` are not set.
|
||||
embeddings: The c-TF-IDF matrix on which to model the hierarchy
|
||||
linkage_function: The linkage function to use. Default is:
|
||||
`lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
|
||||
NOTE: Make sure to use the same `linkage_function` as used
|
||||
in `topic_model.hierarchical_topics`.
|
||||
distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
|
||||
`lambda x: 1 - cosine_similarity(x)`.
|
||||
You can pass any function that returns either a square matrix of
|
||||
shape (n_samples, n_samples) with zeros on the diagonal and
|
||||
non-negative values or condensed distance matrix of shape
|
||||
(n_samples * (n_samples - 1) / 2,) containing the upper
|
||||
triangular of the distance matrix.
|
||||
NOTE: Make sure to use the same `distance_function` as used
|
||||
in `topic_model.hierarchical_topics`.
|
||||
orientation: The orientation of the figure.
|
||||
Either 'left' or 'bottom'
|
||||
custom_labels: Whether to use custom topic labels that were defined using
|
||||
`topic_model.set_topic_labels`.
|
||||
NOTE: Custom labels are only generated for the original
|
||||
un-merged topics.
|
||||
|
||||
Returns:
|
||||
text_annotations: Annotations to be used within Plotly's `ff.create_dendogram`
|
||||
"""
|
||||
df = hierarchical_topics.loc[hierarchical_topics.Parent_Name != "Top", :]
|
||||
|
||||
# Calculate distance
|
||||
X = distance_function(embeddings)
|
||||
X = validate_distance_matrix(X, embeddings.shape[0])
|
||||
|
||||
# Calculate linkage and generate dendrogram
|
||||
Z = linkage_function(X)
|
||||
P = sch.dendrogram(Z, orientation=orientation, no_plot=True)
|
||||
|
||||
# store topic no.(leaves) corresponding to the x-ticks in dendrogram
|
||||
x_ticks = np.arange(5, len(P["leaves"]) * 10 + 5, 10)
|
||||
x_topic = dict(zip(P["leaves"], x_ticks))
|
||||
|
||||
topic_vals = dict()
|
||||
for key, val in x_topic.items():
|
||||
topic_vals[val] = [key]
|
||||
|
||||
parent_topic = dict(zip(df.Parent_ID, df.Topics))
|
||||
|
||||
# loop through every trace (scatter plot) in dendrogram
|
||||
text_annotations = []
|
||||
for index, trace in enumerate(P["icoord"]):
|
||||
fst_topic = topic_vals[trace[0]]
|
||||
scnd_topic = topic_vals[trace[2]]
|
||||
|
||||
if len(fst_topic) == 1:
|
||||
if isinstance(custom_labels, str):
|
||||
fst_name = f"{fst_topic[0]}_" + "_".join(
|
||||
list(zip(*topic_model.topic_aspects_[custom_labels][fst_topic[0]]))[0][:3]
|
||||
)
|
||||
elif topic_model.custom_labels_ is not None and custom_labels:
|
||||
fst_name = topic_model.custom_labels_[fst_topic[0] + topic_model._outliers]
|
||||
else:
|
||||
fst_name = "_".join([word for word, _ in topic_model.get_topic(fst_topic[0])][:5])
|
||||
else:
|
||||
for key, value in parent_topic.items():
|
||||
if set(value) == set(fst_topic):
|
||||
fst_name = df.loc[df.Parent_ID == key, "Parent_Name"].values[0]
|
||||
|
||||
if len(scnd_topic) == 1:
|
||||
if isinstance(custom_labels, str):
|
||||
scnd_name = f"{scnd_topic[0]}_" + "_".join(
|
||||
list(zip(*topic_model.topic_aspects_[custom_labels][scnd_topic[0]]))[0][:3]
|
||||
)
|
||||
elif topic_model.custom_labels_ is not None and custom_labels:
|
||||
scnd_name = topic_model.custom_labels_[scnd_topic[0] + topic_model._outliers]
|
||||
else:
|
||||
scnd_name = "_".join([word for word, _ in topic_model.get_topic(scnd_topic[0])][:5])
|
||||
else:
|
||||
for key, value in parent_topic.items():
|
||||
if set(value) == set(scnd_topic):
|
||||
scnd_name = df.loc[df.Parent_ID == key, "Parent_Name"].values[0]
|
||||
|
||||
text_annotations.append([fst_name, "", "", scnd_name])
|
||||
|
||||
center = (trace[0] + trace[2]) / 2
|
||||
topic_vals[center] = fst_topic + scnd_topic
|
||||
|
||||
return text_annotations
|
||||
Reference in New Issue
Block a user