avsolatorio's picture
Add device
09767db verified
|
raw
history blame
5.74 kB
metadata
license: mit

The WBG Doc Topic Container

from transformers import pipeline
from tqdm.auto import tqdm
import pandas as pd
from transformers import AutoTokenizer


class WBGDocTopic:
    """
    A class to handle document topic suggestion using multiple pre-trained text classification models.

    This class loads a set of text classification models from Hugging Face's model hub and 
    provides a method to suggest topics for input documents based on the aggregated classification 
    results from all the models.

    Attributes:
    -----------
    classifiers : dict
        A dictionary mapping model names to corresponding classification pipelines. It holds 
        instances of Hugging Face's `pipeline` used for text classification.

    Methods:
    --------
    __init__(classifiers: dict = None)
        Initializes the `WBGDocTopic` instance. If no classifiers are provided, it loads a default 
        set of classifiers by calling `load_classifiers`.
    
    load_classifiers()
        Loads a predefined set of document topic classifiers into the `classifiers` dictionary. 
        It uses `tqdm` to display progress as the classifiers are loaded.

    suggest_topics(input_docs: str | list[str]) -> list
        Suggests topics for the given document or list of documents. It runs each document 
        through all classifiers, averages their scores, and returns a list of dictionaries where each 
        dictionary contains the mean and standard deviation of the topic scores per document.
        
        Parameters:
        -----------
        input_docs : str or list of str
            A single document or a list of documents for which to suggest topics.

        Returns:
        --------
        list
            A list of dictionaries, where each dictionary represents the suggested topics for 
            each document, along with the mean and standard deviation of the topic classification scores.
    """

    def __init__(self, classifiers: dict = None, device: str = None
        self.classifiers = classifiers or {}
        self.device = device

        if classifiers is None:
            self.load_classifiers()

    def load_classifiers(self):
        num_evals = 5
        num_train = 5

        tokenizer = AutoTokenizer.from_pretrained("avsolatorio/doc-topic-model_eval-04_train-03")

        for i in tqdm(range(num_evals)):
            for j in tqdm(range(num_train)):
                if i == j:
                    continue

                model_name = f"avsolatorio/doc-topic-model_eval-{i:02}_train-{j:02}"
                classifier = pipeline("text-classification", model=model_name, tokenizer=tokenizer, top_k=None, device=self.device)

                self.classifiers[model_name] = classifier

    def suggest_topics(self, input_docs: str | list[str]):
        if isinstance(input_docs, str):
            input_docs = [input_docs]

        doc_outs = {i: [] for i in range(len(input_docs))}
        topics = []

        for _, classifier in self.classifiers.items():
            for doc_idx, doc in enumerate(classifier(input_docs)):
                doc_outs[doc_idx].append(pd.DataFrame.from_records(doc, index="label"))

        for doc_idx, outs in doc_outs.items():
            all_scores = pd.concat(outs, axis=1)
            mean_probs = all_scores.mean(axis=1).sort_values(ascending=False)
            std_probs = all_scores.std(axis=1).loc[mean_probs.index]
            output = pd.DataFrame({"score_mean": mean_probs, "score_std": std_probs})

            output["doc_idx"] = doc_idx
            output.reset_index(inplace=True)

            topics.append(output.to_dict(orient="records"))

        return topics

Using the WBGDocTopic model

import nltk

# Download the NLTK sentence-tokenizer data if not present. Recent NLTK
# releases look up "punkt_tab" while older ones use "punkt", so fetch both.
nltk.download('punkt_tab')
nltk.download('punkt')

# Load the sent_tokenize method for quick sentence extraction
from nltk import sent_tokenize

# Process the input
sample_text = """A growing literature attributes gender inequality in labor market outcomes in part to the reduction in female labor supply after childbirth, the child penalty. However, if social norms constrain married women’s activities outside the home, then marriage can independently reduce employment, even in the absence childbearing. Given the correlation in timing between childbirth and marriage, conventional estimates of child penalties will conflate these two effects. The paper studies the marriage penalty in South Asia, a context featuring conservative gender norms and low female labor force participation. The study introduces a split-sample, pseudo-panel approach that allows for the separation of marriage and child penalties even in the absence of individual-level panel data. Marriage reduces women’s labor force participation in South Asia by 12 percentage points, whereas the marginal penalty of childbearing is small. Consistent with the central roles of both opportunity costs and social norms, the marriage penalty is smaller among cohorts with higher education and less conservative gender attitudes."""
# Split the sample into sentences; each sentence is scored as its own document.
sents = sent_tokenize(sample_text)

# Create the instance which will load the models.
# Set the device to "cuda" if you want to use a GPU.
dtopic_model = WBGDocTopic(device=None)

# Infer the topics and scores: one list of per-label records per sentence.
outs = dtopic_model.suggest_topics(sents)
outs
# [[{'label': 'Gender',
#   'score_mean': 0.8776359841227531,
#   'score_std': 0.13074095501538094,
#   'doc_idx': 0},
#  {'label': 'Labor Markets',
#   'score_mean': 0.20742715448141097,
#   'score_std': 0.20991565414467345,
#   'doc_idx': 0},
#  {'label': "Girls' Education",
#   'score_mean': 0.19432228063233198,
#   'score_std': 0.21148874269682794,
#   'doc_idx': 0}, ...]]