105 lines
4.0 KiB
Python
105 lines
4.0 KiB
Python
from typing import Any, List, Mapping, Optional, Tuple, Union

import pandas as pd
from scipy.sparse import csr_matrix
from transformers import pipeline
from transformers.pipelines.base import Pipeline

from bertopic.representation._base import BaseRepresentation
|
|
|
|
|
|
class ZeroShotClassification(BaseRepresentation):
    """Zero-shot Classification on topic keywords with candidate labels.

    The keywords of every topic are joined into one string and classified
    against `candidate_topics`. Labels scoring above `min_prob` replace the
    keyword representation of that topic; topics without such a label keep
    their original representation.

    Arguments:
        candidate_topics: A list of labels to assign to the topics if they
                          exceed `min_prob`
        model: Either the name of a model, in which case a
               "zero-shot-classification" transformers pipeline is created
               for it, or a transformers pipeline that was already
               initialized as "zero-shot-classification". For example,
               `pipeline("zero-shot-classification", model="facebook/bart-large-mnli")`
        pipeline_kwargs: Kwargs that you can pass to the transformers.pipeline
                         when it is called. NOTE: Use `{"multi_label": True}`
                         to extract multiple labels for each topic.
        min_prob: The minimum probability to assign a candidate label to a topic

    Usage:

    ```python
    from bertopic.representation import ZeroShotClassification
    from bertopic import BERTopic

    # Create your representation model
    candidate_topics = ["space and nasa", "bicycles", "sports"]
    representation_model = ZeroShotClassification(candidate_topics, model="facebook/bart-large-mnli")

    # Use the representation model in BERTopic on top of the default pipeline
    topic_model = BERTopic(representation_model=representation_model)
    ```
    """

    # Minimum number of (label, score) items in every returned representation;
    # shorter representations are padded with `("", 0)` placeholders.
    _N_ITEMS = 10

    def __init__(
        self,
        candidate_topics: List[str],
        model: Union[str, Pipeline] = "facebook/bart-large-mnli",
        pipeline_kwargs: Optional[Mapping[str, Any]] = None,
        min_prob: float = 0.8,
    ):
        """Store the configuration and resolve `model` to a pipeline.

        Raises:
            ValueError: If `model` is neither a string nor a
                        `transformers` `Pipeline` instance.
        """
        self.candidate_topics = candidate_topics
        if isinstance(model, str):
            self.model = pipeline("zero-shot-classification", model=model)
        elif isinstance(model, Pipeline):
            self.model = model
        else:
            raise ValueError(
                "Make sure that the HF model that you "
                "pass is either a string referring to a "
                "HF model or a `transformers.pipeline` object."
            )
        # A fresh dict per instance: a `{}` default in the signature would be
        # one object shared by every instance created without this argument.
        # A mapping supplied by the caller is stored as-is (not copied).
        self.pipeline_kwargs = {} if pipeline_kwargs is None else pipeline_kwargs
        self.min_prob = min_prob

    def extract_topics(
        self,
        topic_model,
        documents: pd.DataFrame,
        c_tf_idf: csr_matrix,
        topics: Mapping[str, List[Tuple[str, float]]],
    ) -> Mapping[str, List[Tuple[str, float]]]:
        """Extract topics.

        Arguments:
            topic_model: Not used
            documents: Not used
            c_tf_idf: Not used
            topics: The candidate topics as calculated with c-TF-IDF

        Returns:
            updated_topics: Updated topic representations. Each value is a
                            new list (the lists in `topics` are never
                            modified) holding the accepted labels, or the
                            original representation when no label exceeds
                            `min_prob`, padded with `("", 0)` up to 10 items.
                            An empty `topics` mapping yields an empty dict.
        """
        # Nothing to classify; avoid calling the pipeline with no sequences.
        if not topics:
            return {}

        # Classify topics
        topic_ids = list(topics)
        topic_descriptions = [
            # The first element of every tuple is the keyword; an empty
            # keyword list simply gives an empty description.
            " ".join(item[0] for item in topics[topic])
            for topic in topic_ids
        ]
        classifications = self.model(topic_descriptions, self.candidate_topics, **self.pipeline_kwargs)
        # Guard for a callable that hands back a bare result dict instead of
        # a list of results when there is only one description.
        if isinstance(classifications, dict):
            classifications = [classifications]

        # Extract labels
        multi_label = bool(self.pipeline_kwargs.get("multi_label"))
        updated_topics = {}
        for topic, classification in zip(topic_ids, classifications):
            labels = classification["labels"]
            scores = classification["scores"]

            if multi_label:
                # Multi-label assignment: keep every label above the threshold
                topic_description = [(label, score) for label, score in zip(labels, scores) if score > self.min_prob]
            elif scores[0] > self.min_prob:
                # Single label assignment: only the first (label, score) pair
                topic_description = [(labels[0], scores[0])]
            else:
                topic_description = []

            # No label was accepted: fall back to a copy of the c-TF-IDF
            # representation so the caller's list is never aliased or mutated.
            if not topic_description:
                topic_description = list(topics[topic])

            # Make sure that at least 10 items are returned
            missing = self._N_ITEMS - len(topic_description)
            if missing > 0:
                topic_description = topic_description + [("", 0)] * missing

            updated_topics[topic] = topic_description

        return updated_topics
|