---
license: mit
---

# The WBG Doc Topic Container

`WBGDocTopic` wraps an ensemble of fine-tuned text-classification models and aggregates their per-topic scores (mean and standard deviation) for each input document.

```python
from transformers import pipeline
from tqdm.auto import tqdm
import pandas as pd
from transformers import AutoTokenizer


class WBGDocTopic:
    """
    A class to handle document topic suggestion using multiple pre-trained
    text classification models.

    This class loads a set of text classification models from Hugging Face's
    model hub and provides a method to suggest topics for input documents
    based on the aggregated classification results from all the models.

    Attributes:
    -----------
    classifiers : dict
        A dictionary mapping model names to their corresponding Hugging Face
        text-classification `pipeline` instances.
    device : str or None
        The device passed to every `pipeline` that `load_classifiers` creates
        (e.g. "cuda"). `None` keeps the `pipeline` default.

    Methods:
    --------
    __init__(classifiers: dict = None, device: str = None)
        Initializes the `WBGDocTopic` instance. If no classifiers are provided,
        it loads the default set of classifiers by calling `load_classifiers`.

    load_classifiers(num_evals=5, num_train=5, tokenizer_name=DEFAULT_TOKENIZER)
        Loads the predefined set of document topic classifiers into the
        `classifiers` dictionary, using `tqdm` to display loading progress.

    suggest_topics(input_docs: str | list[str]) -> list
        Suggests topics for the given document or list of documents. It runs
        each document through all classifiers and returns, per document, a
        list of records with the mean and standard deviation of each topic's
        score across classifiers.
    """

    # Tokenizer shared by all default fold models.
    # NOTE(review): presumably every fold model uses the same base tokenizer — confirm.
    DEFAULT_TOKENIZER = "avsolatorio/doc-topic-model_eval-04_train-03"

    def __init__(self, classifiers: dict = None, device: str = None):
        """
        Parameters:
        -----------
        classifiers : dict, optional
            Pre-built mapping of model name -> text-classification pipeline.
            When `None`, the default classifiers are loaded from the hub.
        device : str, optional
            Device forwarded to each `pipeline` created by `load_classifiers`.
        """
        self.classifiers = classifiers or {}
        self.device = device

        if classifiers is None:
            self.load_classifiers()

    def load_classifiers(
        self,
        num_evals: int = 5,
        num_train: int = 5,
        tokenizer_name: str = DEFAULT_TOKENIZER,
    ):
        """
        Load every `eval-{i:02}_train-{j:02}` classifier with `i != j`.

        Parameters:
        -----------
        num_evals : int
            `i` ranges over `range(num_evals)`.
        num_train : int
            `j` ranges over `range(num_train)`.
        tokenizer_name : str
            Hub id of the tokenizer shared by all loaded pipelines.
        """
        # Load the tokenizer once and reuse it for every pipeline.
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        for i in tqdm(range(num_evals)):
            for j in tqdm(range(num_train)):
                if i == j:
                    # Pairs whose eval and train indices coincide are skipped.
                    continue

                model_name = f"avsolatorio/doc-topic-model_eval-{i:02}_train-{j:02}"
                self.classifiers[model_name] = pipeline(
                    "text-classification",
                    model=model_name,
                    tokenizer=tokenizer,
                    top_k=None,  # return a score for every label, not only the best one
                    device=self.device,
                )

    def suggest_topics(self, input_docs: str | list[str]):
        """
        Suggest topics for one document or a list of documents.

        Each document is scored by every loaded classifier, and the per-label
        scores are aggregated across classifiers.

        Parameters:
        -----------
        input_docs : str or list of str
            A single document or a list of documents for which to suggest topics.

        Returns:
        --------
        list
            One list per input document, in input order. Each inner list holds
            one dict per label with keys `label`, `score_mean`, `score_std` and
            `doc_idx`, sorted by `score_mean` in descending order. `score_std`
            is the pandas sample standard deviation, so it is NaN when only one
            classifier is loaded.

        Raises:
        -------
        ValueError
            If documents are given but no classifiers are loaded.
        """
        if isinstance(input_docs, str):
            input_docs = [input_docs]

        if not input_docs:
            return []

        if not self.classifiers:
            raise ValueError("No classifiers loaded; call `load_classifiers()` first.")

        # For each document, collect one label-indexed score frame per classifier.
        doc_outs = [[] for _ in input_docs]
        for classifier in self.classifiers.values():
            for doc_idx, doc_scores in enumerate(classifier(input_docs)):
                doc_outs[doc_idx].append(pd.DataFrame.from_records(doc_scores, index="label"))

        topics = []
        for doc_idx, outs in enumerate(doc_outs):
            # Rows are labels; there is one score column per classifier.
            all_scores = pd.concat(outs, axis=1)
            mean_probs = all_scores.mean(axis=1).sort_values(ascending=False)
            std_probs = all_scores.std(axis=1).loc[mean_probs.index]

            output = pd.DataFrame({"score_mean": mean_probs, "score_std": std_probs})
            output["doc_idx"] = doc_idx
            # reset_index turns the "label" index into the first column.
            topics.append(output.reset_index().to_dict(orient="records"))

        return topics
```

# Using the WBGDocTopic model

```python
import nltk

# Download the nltk sentence-tokenizer data if not already present.
nltk.download('punkt_tab')
nltk.download('punkt')

from collections import Counter

# Load the sent_tokenize function for quick sentence extraction.
from nltk import sent_tokenize

# Process the input
sample_text = """A growing literature attributes gender inequality in labor market outcomes in part to the reduction in female labor supply after childbirth, the child penalty.
However, if social norms constrain married women’s activities outside the home, then marriage can independently reduce employment, even in the absence childbearing. Given the correlation in timing between childbirth and marriage, conventional estimates of child penalties will conflate these two effects. The paper studies the marriage penalty in South Asia, a context featuring conservative gender norms and low female labor force participation. The study introduces a split-sample, pseudo-panel approach that allows for the separation of marriage and child penalties even in the absence of individual-level panel data. Marriage reduces women’s labor force participation in South Asia by 12 percentage points, whereas the marginal penalty of childbearing is small. Consistent with the central roles of both opportunity costs and social norms, the marriage penalty is smaller among cohorts with higher education and less conservative gender attitudes."""

# Split the abstract into sentences so topics are suggested per sentence.
sents = sent_tokenize(sample_text)

# Create the instance, which will load the models.
# Set the device to "cuda" if you want to use a GPU.
dtopic_model = WBGDocTopic(device=None)

# Infer the topics and scores: one list of label records per sentence.
outs = dtopic_model.suggest_topics(sents)
outs
# [[{'label': 'Gender',
#    'score_mean': 0.8776359841227531,
#    'score_std': 0.13074095501538094,
#    'doc_idx': 0},
#   {'label': 'Labor Markets',
#    'score_mean': 0.20742715448141097,
#    'score_std': 0.20991565414467345,
#    'doc_idx': 0},
#   {'label': "Girls' Education",
#    'score_mean': 0.19432228063233198,
#    'score_std': 0.21148874269682794,
#    'doc_idx': 0}, ...]]

# Count, across sentences, how often each topic is highly relevant.
# A label counts when its mean score exceeds a (currently arbitrary) threshold
# of 0.1 and also exceeds its standard deviation across classifiers.
Counter(
    o["label"]
    for out in outs
    for o in out
    if o["score_mean"] > 0.1 and o["score_mean"] > o["score_std"]
).most_common()
```