from abc import ABC
from typing import List, Tuple

from modules.module_rankSents import RankSents
from modules.module_crowsPairs import CrowsPairs


class Connector(ABC):
    """Shared helpers for connectors: word normalization and error formatting."""

    def parse_word(
        self,
        word: str
    ) -> str:
        """Return *word* lower-cased with surrounding whitespace removed."""
        return word.lower().strip()

    def parse_words(
        self,
        array_in_string: str
    ) -> List[str]:
        """Split a comma-separated string into normalized words.

        Args:
            array_in_string: Words separated by commas, e.g. ``"a, B ,c"``.
                A falsy value (``""`` or ``None``) yields an empty list.

        Returns:
            The entries passed through :meth:`parse_word`; entries that are
            empty or whitespace-only are dropped.
        """
        # Guard against None as well as "" so callers that pass an unset
        # field get [] instead of an AttributeError.
        words = (array_in_string or "").strip()
        if not words:
            return []

        return [
            self.parse_word(word)
            for word in words.split(',')
            if word.strip() != ''
        ]

    def process_error(
        self,
        err: str
    ) -> str:
        """Format an error message for display.

        A truthy *err* is surrounded by a pair of line breaks on each side;
        a falsy *err* is returned unchanged.
        """
        if err:
            # NOTE(review): in the collapsed source these literals appeared as
            # bare line breaks spanning physical lines, which made the module
            # unparseable; they are rendered here as "\n\n". The original may
            # have contained other text that was lost — confirm against
            # version control.
            err = "\n\n" + err + "\n\n"
        return err


class PhraseBiasExplorerConnector(Connector):
    """Connector that forwards sentence-ranking requests to a ``RankSents``."""

    def __init__(
        self,
        **kwargs
    ) -> None:
        """Build the underlying ``RankSents`` explorer.

        Keyword Args:
            language_model: Passed to ``RankSents`` as ``language_model``.
            lang: Passed to ``RankSents`` as ``lang``.

        Raises:
            KeyError: If ``language_model`` or ``lang`` is missing or None;
                the exception argument names the offending key.
        """
        language_model = kwargs.get('language_model', None)
        lang = kwargs.get('lang', None)
        if language_model is None:
            raise KeyError('language_model')
        if lang is None:
            raise KeyError('lang')

        self.phrase_bias_explorer = RankSents(
            language_model=language_model,
            lang=lang
        )

    def rank_sentence_options(
        self,
        sent: str,
        word_list: str,
        banned_word_list: str,
        useArticles: bool,
        usePrepositions: bool,
        useConjunctions: bool
    ) -> Tuple:
        """Rank candidate fillings of *sent* and return the labelled scores.

        Args:
            sent: Input sentence. Every ``*`` is padded with spaces and runs
                of whitespace are collapsed before checking/ranking
                (presumably ``*`` marks the slot to fill — confirm in
                ``RankSents``).
            word_list: Comma-separated words, parsed with :meth:`parse_words`.
            banned_word_list: Comma-separated words, parsed the same way.
            useArticles, usePrepositions, useConjunctions: Passed through
                unchanged to ``RankSents.rank``.

        Returns:
            A 3-tuple ``(error, scores, "")``. If
            ``RankSents.errorChecking`` returns a truthy value, ``error`` is
            that value formatted by :meth:`process_error` and ``scores`` is
            ``""``. Otherwise ``error`` is the falsy check result and
            ``scores`` is the output of ``Label.compute`` on the ranking.
            The third element is always ``""``.
        """
        # Make "*" a standalone token, then normalize all whitespace.
        sent = " ".join(sent.strip().replace("*", " * ").split())

        err = self.phrase_bias_explorer.errorChecking(sent)
        if err:
            return self.process_error(err), "", ""

        word_list = self.parse_words(word_list)
        banned_word_list = self.parse_words(banned_word_list)

        all_plls_scores = self.phrase_bias_explorer.rank(
            sent,
            word_list,
            banned_word_list,
            useArticles,
            usePrepositions,
            useConjunctions
        )

        all_plls_scores = self.phrase_bias_explorer.Label.compute(all_plls_scores)
        # err is falsy here, so process_error returns it unchanged.
        return self.process_error(err), all_plls_scores, ""


class CrowsPairsExplorerConnector(Connector):
    """Connector that forwards sentence comparisons to a ``CrowsPairs``."""

    def __init__(
        self,
        **kwargs
    ) -> None:
        """Build the underlying ``CrowsPairs`` explorer.

        Keyword Args:
            language_model: Passed to ``CrowsPairs`` as ``language_model``.

        Raises:
            KeyError: If ``language_model`` is missing or None.
        """
        language_model = kwargs.get('language_model', None)
        if language_model is None:
            raise KeyError('language_model')

        self.crows_pairs_explorer = CrowsPairs(
            language_model=language_model
        )

    def compare_sentences(
        self,
        sent0: str,
        sent1: str,
        sent2: str,
        sent3: str,
        sent4: str,
        sent5: str
    ) -> Tuple:
        """Score six sentences together and return the labelled scores.

        Returns:
            A 3-tuple ``(error, scores, "")``. If
            ``CrowsPairs.errorChecking`` returns a truthy value for the
            sentence list, ``error`` is that value formatted by
            :meth:`process_error` and ``scores`` is ``""``. Otherwise
            ``error`` is the falsy check result and ``scores`` is the output
            of ``Label.compute`` on the ranking. The third element is always
            ``""``.
        """
        sent_list = [sent0, sent1, sent2, sent3, sent4, sent5]
        err = self.crows_pairs_explorer.errorChecking(
            sent_list
        )

        if err:
            return self.process_error(err), "", ""

        all_plls_scores = self.crows_pairs_explorer.rank(
            sent_list
        )

        all_plls_scores = self.crows_pairs_explorer.Label.compute(all_plls_scores)
        # err is falsy here, so process_error returns it unchanged.
        return self.process_error(err), all_plls_scores, ""