Add BERTopic.
This commit is contained in:
@@ -0,0 +1,274 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from scipy.sparse import csr_matrix
|
||||
from typing import Mapping, List, Tuple, Union
|
||||
from transformers.pipelines import Pipeline, pipeline
|
||||
|
||||
from bertopic.representation._mmr import mmr
|
||||
from bertopic.representation._base import BaseRepresentation
|
||||
|
||||
|
||||
class VisualRepresentation(BaseRepresentation):
|
||||
"""From a collection of representative documents, extract
|
||||
images to represent topics. These topics are represented by a
|
||||
collage of images.
|
||||
|
||||
Arguments:
|
||||
nr_repr_images: Number of representative images to extract
|
||||
nr_samples: The number of candidate documents to extract per cluster.
|
||||
image_height: The height of the resulting collage
|
||||
image_square: Whether to resize each image in the collage
|
||||
to a square. This can be visually more appealing
|
||||
if all input images are all almost squares.
|
||||
image_to_text_model: The model to caption images.
|
||||
batch_size: The number of images to pass to the
|
||||
`image_to_text_model`.
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
from bertopic.representation import VisualRepresentation
|
||||
from bertopic import BERTopic
|
||||
|
||||
# The visual representation is typically not a core representation
|
||||
# and is advised to pass to BERTopic as an additional aspect.
|
||||
# Aspects can be labeled with dictionaries as shown below:
|
||||
representation_model = {
|
||||
"Visual_Aspect": VisualRepresentation()
|
||||
}
|
||||
|
||||
# Use the representation model in BERTopic as a separate aspect
|
||||
topic_model = BERTopic(representation_model=representation_model)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nr_repr_images: int = 9,
|
||||
nr_samples: int = 500,
|
||||
image_height: Tuple[int, int] = 600,
|
||||
image_squares: bool = False,
|
||||
image_to_text_model: Union[str, Pipeline] = None,
|
||||
batch_size: int = 32,
|
||||
):
|
||||
self.nr_repr_images = nr_repr_images
|
||||
self.nr_samples = nr_samples
|
||||
self.image_height = image_height
|
||||
self.image_squares = image_squares
|
||||
|
||||
# Text-to-image model
|
||||
if isinstance(image_to_text_model, Pipeline):
|
||||
self.image_to_text_model = image_to_text_model
|
||||
elif isinstance(image_to_text_model, str):
|
||||
self.image_to_text_model = pipeline("image-to-text", model=image_to_text_model)
|
||||
elif image_to_text_model is None:
|
||||
self.image_to_text_model = None
|
||||
else:
|
||||
raise ValueError(
|
||||
"Please select a correct transformers pipeline. For example:"
|
||||
"pipeline('image-to-text', model='nlpconnect/vit-gpt2-image-captioning')"
|
||||
)
|
||||
self.batch_size = batch_size
|
||||
|
||||
def extract_topics(
|
||||
self,
|
||||
topic_model,
|
||||
documents: pd.DataFrame,
|
||||
c_tf_idf: csr_matrix,
|
||||
topics: Mapping[str, List[Tuple[str, float]]],
|
||||
) -> Mapping[str, List[Tuple[str, float]]]:
|
||||
"""Extract topics.
|
||||
|
||||
Arguments:
|
||||
topic_model: A BERTopic model
|
||||
documents: All input documents
|
||||
c_tf_idf: The topic c-TF-IDF representation
|
||||
topics: The candidate topics as calculated with c-TF-IDF
|
||||
|
||||
Returns:
|
||||
representative_images: Representative images per topic
|
||||
"""
|
||||
# Extract image ids of most representative documents
|
||||
images = documents["Image"].values.tolist()
|
||||
(_, _, _, repr_docs_ids) = topic_model._extract_representative_docs(
|
||||
c_tf_idf,
|
||||
documents,
|
||||
topics,
|
||||
nr_samples=self.nr_samples,
|
||||
nr_repr_docs=self.nr_repr_images,
|
||||
)
|
||||
unique_topics = sorted(list(topics.keys()))
|
||||
|
||||
# Combine representative images into a single representation
|
||||
representative_images = {}
|
||||
for topic in tqdm(unique_topics):
|
||||
# Get and order represetnative images
|
||||
sliced_examplars = repr_docs_ids[topic + topic_model._outliers]
|
||||
sliced_examplars = [sliced_examplars[i : i + 3] for i in range(0, len(sliced_examplars), 3)]
|
||||
images_to_combine = [
|
||||
[
|
||||
Image.open(images[index]) if isinstance(images[index], str) else images[index]
|
||||
for index in sub_indices
|
||||
]
|
||||
for sub_indices in sliced_examplars
|
||||
]
|
||||
|
||||
# Concatenate representative images
|
||||
representative_image = get_concat_tile_resize(images_to_combine, self.image_height, self.image_squares)
|
||||
representative_images[topic] = representative_image
|
||||
|
||||
# Make sure to properly close images
|
||||
if isinstance(images[0], str):
|
||||
for image_list in images_to_combine:
|
||||
for image in image_list:
|
||||
image.close()
|
||||
|
||||
return representative_images
|
||||
|
||||
def _convert_image_to_text(self, images: List[str], verbose: bool = False) -> List[str]:
|
||||
"""Convert a list of images to captions.
|
||||
|
||||
Arguments:
|
||||
images: A list of images or words to be converted to text.
|
||||
verbose: Controls the verbosity of the process
|
||||
|
||||
Returns:
|
||||
List of captions
|
||||
"""
|
||||
# Batch-wise image conversion
|
||||
if self.batch_size is not None:
|
||||
documents = []
|
||||
for batch in tqdm(self._chunks(images), disable=not verbose):
|
||||
outputs = self.image_to_text_model(batch)
|
||||
captions = [output[0]["generated_text"] for output in outputs]
|
||||
documents.extend(captions)
|
||||
|
||||
# Convert images to text
|
||||
else:
|
||||
outputs = self.image_to_text_model(images)
|
||||
documents = [output[0]["generated_text"] for output in outputs]
|
||||
|
||||
return documents
|
||||
|
||||
def image_to_text(self, documents: pd.DataFrame, embeddings: np.ndarray) -> pd.DataFrame:
|
||||
"""Convert images to text."""
|
||||
# Create image topic embeddings
|
||||
topics = documents.Topic.values.tolist()
|
||||
images = documents.Image.values.tolist()
|
||||
df = pd.DataFrame(np.hstack([np.array(topics).reshape(-1, 1), embeddings]))
|
||||
image_topic_embeddings = df.groupby(0).mean().values
|
||||
|
||||
# Extract image centroids
|
||||
image_centroids = {}
|
||||
unique_topics = sorted(list(set(topics)))
|
||||
for topic, topic_embedding in zip(unique_topics, image_topic_embeddings):
|
||||
indices = np.array([index for index, t in enumerate(topics) if t == topic])
|
||||
top_n = min([self.nr_repr_images, len(indices)])
|
||||
indices = mmr(
|
||||
topic_embedding.reshape(1, -1),
|
||||
embeddings[indices],
|
||||
indices,
|
||||
top_n=top_n,
|
||||
diversity=0.1,
|
||||
)
|
||||
image_centroids[topic] = indices
|
||||
|
||||
# Extract documents
|
||||
documents = pd.DataFrame(columns=["Document", "ID", "Topic", "Image"])
|
||||
current_id = 0
|
||||
for topic, image_ids in tqdm(image_centroids.items()):
|
||||
selected_images = [
|
||||
Image.open(images[index]) if isinstance(images[index], str) else images[index] for index in image_ids
|
||||
]
|
||||
text = self._convert_image_to_text(selected_images)
|
||||
|
||||
for doc, image_id in zip(text, image_ids):
|
||||
documents.loc[len(documents), :] = [
|
||||
doc,
|
||||
current_id,
|
||||
topic,
|
||||
images[image_id],
|
||||
]
|
||||
current_id += 1
|
||||
|
||||
# Properly close images
|
||||
if isinstance(images[image_ids[0]], str):
|
||||
for image in selected_images:
|
||||
image.close()
|
||||
|
||||
return documents
|
||||
|
||||
def _chunks(self, images):
|
||||
for i in range(0, len(images), self.batch_size):
|
||||
yield images[i : i + self.batch_size]
|
||||
|
||||
|
||||
def get_concat_h_multi_resize(im_list):
|
||||
"""Code adapted from: https://note.nkmk.me/en/python-pillow-concat-images/."""
|
||||
min_height = min(im.height for im in im_list)
|
||||
min_height = max(im.height for im in im_list)
|
||||
im_list_resize = []
|
||||
for im in im_list:
|
||||
im.resize((int(im.width * min_height / im.height), min_height), resample=0)
|
||||
im_list_resize.append(im)
|
||||
|
||||
total_width = sum(im.width for im in im_list_resize)
|
||||
dst = Image.new("RGB", (total_width, min_height), (255, 255, 255))
|
||||
pos_x = 0
|
||||
for im in im_list_resize:
|
||||
dst.paste(im, (pos_x, 0))
|
||||
pos_x += im.width
|
||||
return dst
|
||||
|
||||
|
||||
def get_concat_v_multi_resize(im_list):
|
||||
"""Code adapted from: https://note.nkmk.me/en/python-pillow-concat-images/."""
|
||||
min_width = min(im.width for im in im_list)
|
||||
min_width = max(im.width for im in im_list)
|
||||
im_list_resize = [im.resize((min_width, int(im.height * min_width / im.width)), resample=0) for im in im_list]
|
||||
total_height = sum(im.height for im in im_list_resize)
|
||||
dst = Image.new("RGB", (min_width, total_height), (255, 255, 255))
|
||||
pos_y = 0
|
||||
for im in im_list_resize:
|
||||
dst.paste(im, (0, pos_y))
|
||||
pos_y += im.height
|
||||
return dst
|
||||
|
||||
|
||||
def get_concat_tile_resize(im_list_2d, image_height=600, image_squares=False):
|
||||
"""Code adapted from: https://note.nkmk.me/en/python-pillow-concat-images/."""
|
||||
images = [[image.copy() for image in images] for images in im_list_2d]
|
||||
|
||||
# Create
|
||||
if image_squares:
|
||||
width = int(image_height / 3)
|
||||
height = int(image_height / 3)
|
||||
images = [[image.resize((width, height)) for image in images] for images in im_list_2d]
|
||||
|
||||
# Resize images based on minimum size
|
||||
else:
|
||||
min_width = min([min([img.width for img in imgs]) for imgs in im_list_2d])
|
||||
min_height = min([min([img.height for img in imgs]) for imgs in im_list_2d])
|
||||
for i, imgs in enumerate(images):
|
||||
for j, img in enumerate(imgs):
|
||||
if img.height > img.width:
|
||||
images[i][j] = img.resize(
|
||||
(int(img.width * min_height / img.height), min_height),
|
||||
resample=0,
|
||||
)
|
||||
elif img.width > img.height:
|
||||
images[i][j] = img.resize((min_width, int(img.height * min_width / img.width)), resample=0)
|
||||
else:
|
||||
images[i][j] = img.resize((min_width, min_width))
|
||||
|
||||
# Resize grid image
|
||||
images = [get_concat_h_multi_resize(im_list_h) for im_list_h in images]
|
||||
img = get_concat_v_multi_resize(images)
|
||||
height_percentage = image_height / float(img.size[1])
|
||||
adjusted_width = int((float(img.size[0]) * float(height_percentage)))
|
||||
img = img.resize((adjusted_width, image_height), Image.Resampling.LANCZOS)
|
||||
|
||||
return img
|
||||
Reference in New Issue
Block a user