import logging
from argparse import ArgumentParser
from typing import List

from meerkat import DataPanel, SpacyColumn
from meerkat.logging.utils import set_logging_level
from spacy import load

from align import BertscoreAligner, NGramAligner, StaticEmbeddingAligner, Aligner
from utils import clean_text

set_logging_level('critical')
logger = logging.getLogger(__name__)
logger.setLevel(logging.CRITICAL)


def _run_aligners(
        dataset: DataPanel,
        aligners: List[Aligner],
        doc_column: str,
        reference_column: str,
        summary_columns: List[str] = None,
):
    if not summary_columns:
        summary_columns = []

    to_columns = []
    if reference_column is not None:
        to_columns.append(reference_column)
    to_columns.extend(summary_columns)

    for aligner in aligners:

        # Run the aligner on (document, summary) pairs
        dataset = dataset.update(
            lambda x: {
                f'{type(aligner).__name__}:{doc_column}:{to_columns}': aligner.align(
                    x[doc_column],
                    [x[col] for col in to_columns],
                ),
            },
        )

        if reference_column is not None and len(summary_columns):
            # Run the aligner on (reference, summary) pairs
            dataset = dataset.update(
                lambda x: {
                    f'{type(aligner).__name__}:{reference_column}:{summary_columns}': aligner.align(
                        x[reference_column],
                        [x[col] for col in summary_columns],
                    ),
                },
            )

        if len(to_columns) > 1:
            # Instead of having one column for (document, summary) comparisons, split
            # off into (1 + |summary_columns|) total columns, one for each comparison

            # Retrieve the (document, summary) column
            doc_summary_column = dataset[f'{type(aligner).__name__}:{doc_column}:{to_columns}']

            for i, col in enumerate(to_columns):
                # Add as a new column
                dataset.add_column(
                    f'{type(aligner).__name__}:{doc_column}:{col}',
                    [row[i] for row in doc_summary_column],
                )

            # Remove the (document, summary) column
            dataset.remove_column(f'{type(aligner).__name__}:{doc_column}:{to_columns}')

        if reference_column is not None and len(summary_columns) > 1:
            # Instead of having one column for (reference, summary) comparisons, split
            # off into (|summary_columns|) total columns, one for each comparison

            # Retrieve the (reference, summary) column
            reference_summary_column = \
                dataset[f'{type(aligner).__name__}:{reference_column}:{summary_columns}']

            for i, col in enumerate(summary_columns):
                # Add as a new column
                dataset.add_column(
                    f'{type(aligner).__name__}:{reference_column}:{col}',
                    [row[i] for row in reference_summary_column],
                )

            # Remove the (reference, summary) column
            dataset.remove_column(f'{type(aligner).__name__}:{reference_column}:{summary_columns}')

    return dataset


def load_nlp():
    try:
        return load('en_core_web_lg')
    except OSError:
        raise OSError(
            "The spacy model 'en_core_web_lg' is required unless loading from a cached file. "
            "To install: 'python -m spacy download en_core_web_lg'"
        )


def run_workflow(
        jsonl_path: str,
        doc_column: str = None,
        reference_column: str = None,
        summary_columns: List[str] = None,
        bert_aligner_threshold: float = 0.5,
        bert_aligner_top_k: int = 3,
        embedding_aligner_threshold: float = 0.5,
        embedding_aligner_top_k: int = 3,
        processed_dataset_path: str = None,
        no_clean: bool = False,
        n_samples: int = None,
):
    if not jsonl_path:
        raise ValueError("'jsonl_path' is required")

    if not processed_dataset_path:
        raise ValueError("Please specify a path to save the dataset.")

    # Load the dataset
    dataset = DataPanel.from_jsonl(jsonl_path)

    if doc_column is None:
        # Assume `doc_column` is called "document"
        doc_column = 'document'
        assert doc_column in dataset.columns, \
            f"`doc_column={doc_column}` is not a column in datapanel."
print("Assuming `doc_column` is called 'document'.") if reference_column is None: # Assume `reference_column` is called "summary:reference" reference_column = 'summary:reference' print("Assuming `reference_column` is called 'summary:reference'.") if reference_column not in dataset.columns: print("No reference summary loaded") reference_column = None if summary_columns is None or len(summary_columns) == 0: # Assume `summary_columns` are prefixed by "summary:" summary_columns = [] for col in dataset.columns: if col.startswith("summary:") and col != "summary:reference": summary_columns.append(col) print(f"Reading summary columns from datapanel. Found {summary_columns}.") if len(summary_columns) == 0 and reference_column is None: raise ValueError("At least one summary is required") # Restrict to the first `n_samples` if n_samples: print(f"Restricting to {n_samples} samples.") dataset = dataset.head(n_samples) print("size of dataset:", len(dataset)) # Combine the text columns into one list text_columns = [doc_column] + ([reference_column] if reference_column else []) + summary_columns # Preprocessing all the text columns print("Preprocessing text columns") dataset = dataset.update( lambda x: { f'preprocessed_{k}': x[k] if args.no_clean else clean_text(x[k]) for k in text_columns } ) # Run the Spacy pipeline on all preprocessed text columns nlp = load_nlp() nlp.add_pipe('sentencizer', before="parser") print("Running spacy processing") for col in text_columns: dataset.add_column(f'spacy:{col}', SpacyColumn.from_docs(nlp.pipe(dataset[f'preprocessed_{col}']))) # Run the 3 align pipelines bert_aligner = BertscoreAligner( threshold=bert_aligner_threshold, top_k=bert_aligner_top_k, ) embedding_aligner = StaticEmbeddingAligner( threshold=embedding_aligner_threshold, top_k=embedding_aligner_top_k, ) ngram_aligner = NGramAligner() dataset = _run_aligners( dataset=dataset, aligners=[bert_aligner, embedding_aligner, ngram_aligner], doc_column=f'spacy:{doc_column}', reference_column=f'spacy:{reference_column}' if reference_column else None, summary_columns=[f'spacy:{col}' for col in summary_columns], ) # Save the dataset dataset.write(processed_dataset_path) return dataset def standardize_dataset( dataset_name: str, dataset_version: str, dataset_split: str, save_jsonl_path: str, doc_column: str = None, reference_column: str = None, n_samples: int = None ): """Load a dataset from Huggingface and dump it to disk.""" if args.dataset is None or \ args.split is None or \ args.save_jsonl_path is None: raise ValueError('Missing command line argument') # Load the dataset from Huggingface dataset = get_dataset( dataset_name=dataset_name, dataset_version=dataset_version, dataset_split=dataset_split ) if n_samples: dataset = dataset[:n_samples] if doc_column is None: if reference_column is not None: raise ValueError("You must specify `doc_column` if you specify `reference_column`") try: doc_column, reference_column = { 'cnn_dailymail': ('article', 'highlights'), 'xsum': ('document', 'summary') }[dataset_name] except: raise NotImplementedError( "Please specify `doc_column`." 
            )

    # Rename the columns
    if doc_column != 'document':
        dataset.add_column('document', dataset[doc_column])
        dataset.remove_column(doc_column)
    if reference_column != 'summary:reference':
        dataset.add_column('summary:reference', dataset[reference_column])
        dataset.remove_column(reference_column)

    # Save the dataset back to disk
    dataset.to_jsonl(save_jsonl_path)
    return dataset


def get_dataset(
        dataset_name: str = None,
        dataset_version: str = None,
        dataset_split: str = 'test',
        dataset_jsonl: str = None,
):
    """Load a dataset."""
    assert (dataset_name is not None) != (dataset_jsonl is not None), \
        "Specify one of `dataset_name` or `dataset_jsonl`."

    # Load the dataset
    if dataset_name is not None:
        return get_hf_dataset(dataset_name, dataset_version, dataset_split)
    return DataPanel.from_jsonl(json_path=dataset_jsonl)


def get_hf_dataset(name: str, version: str = None, split: str = 'test'):
    """Get a dataset from Huggingface."""
    if version:
        return DataPanel.from_huggingface(name, version, split=split)
    return DataPanel.from_huggingface(name, split=split)


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--dataset', type=str, choices=['cnn_dailymail', 'xsum'],
                        help="Huggingface dataset name.")
    parser.add_argument('--version', type=str,
                        help="Huggingface dataset version.")
    parser.add_argument('--split', type=str, default='test',
                        help="Huggingface dataset split.")
    parser.add_argument('--dataset_jsonl', type=str,
                        help="Path to a jsonl file for the dataset.")
    parser.add_argument('--save_jsonl_path', type=str,
                        help="Path to save the standardized jsonl dataset.")
    parser.add_argument('--doc_column', type=str,
                        help="Name of the document column in the dataset.")
    parser.add_argument('--reference_column', type=str,
                        help="Name of the reference summary column in the dataset.")
    parser.add_argument('--summary_columns', nargs='+', default=[],
                        help="Names of other summary columns in/added to the dataset.")
    parser.add_argument('--bert_aligner_threshold', type=float, default=0.1,
                        help="Minimum threshold for BERT alignment.")
    parser.add_argument('--bert_aligner_top_k', type=int, default=10,
                        help="Top-k for BERT alignment.")
    parser.add_argument('--embedding_aligner_threshold', type=float, default=0.1,
                        help="Minimum threshold for embedding alignment.")
    parser.add_argument('--embedding_aligner_top_k', type=int, default=10,
                        help="Top-k for embedding alignment.")
    parser.add_argument('--processed_dataset_path', type=str,
                        help="Path to store the final processed dataset.")
    parser.add_argument('--n_samples', type=int,
                        help="Number of dataset samples to process.")
    parser.add_argument('--workflow', action='store_true', default=False,
                        help="Whether to run the preprocessing workflow.")
    parser.add_argument('--standardize', action='store_true', default=False,
                        help="Whether to standardize the dataset and save to jsonl.")
    parser.add_argument('--no_clean', action='store_true', default=False,
                        help="Do not clean text (remove extraneous spaces, newlines).")
    args = parser.parse_args()

    if args.standardize:
        # Dump a Huggingface dataset to standardized jsonl format
        standardize_dataset(
            dataset_name=args.dataset,
            dataset_version=args.version,
            dataset_split=args.split,
            save_jsonl_path=args.save_jsonl_path,
            doc_column=args.doc_column,
            reference_column=args.reference_column,
            n_samples=args.n_samples,
        )

    if args.workflow:
        # Run the processing workflow
        run_workflow(
            jsonl_path=args.dataset_jsonl,
            doc_column=args.doc_column,
            reference_column=args.reference_column,
            summary_columns=args.summary_columns,
            bert_aligner_threshold=args.bert_aligner_threshold,
            bert_aligner_top_k=args.bert_aligner_top_k,
            embedding_aligner_threshold=args.embedding_aligner_threshold,
            embedding_aligner_top_k=args.embedding_aligner_top_k,
            processed_dataset_path=args.processed_dataset_path,
            no_clean=args.no_clean,
            n_samples=args.n_samples,
        )
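
# Example usage (a sketch only: the script filename `preprocessing.py`, the dataset
# version, and the file paths below are illustrative assumptions; the flags are the
# ones defined by the argument parser above).
#
#   Standardize a Huggingface dataset into the jsonl format expected by the workflow:
#       python preprocessing.py --standardize --dataset cnn_dailymail --version 3.0.0 \
#           --split test --save_jsonl_path data/cnn_test.jsonl
#
#   Run the preprocessing and alignment workflow on the standardized jsonl file:
#       python preprocessing.py --workflow --dataset_jsonl data/cnn_test.jsonl \
#           --processed_dataset_path data/cnn_test_processed --n_samples 100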