Spaces:
Runtime error
Runtime error
from abc import ABC | |
from modules.module_rankSents import RankSents | |
from modules.module_crowsPairs import CrowsPairs | |
from typing import List, Tuple | |
class Connector(ABC):
    """Shared input/output helpers for the explorer connectors.

    Provides parsing of comma-separated word lists and HTML wrapping of
    error messages; subclasses add the explorer-specific entry points.
    """

    def parse_word(
        self,
        word: str
    ) -> str:
        """Return *word* lower-cased with surrounding whitespace removed."""
        return word.lower().strip()

    def parse_words(
        self,
        array_in_string: str
    ) -> List[str]:
        """Split a comma-separated string into normalized words.

        Items that are empty or whitespace-only (e.g. from ``"a,,b"`` or a
        trailing comma) are dropped; every remaining item is passed through
        :meth:`parse_word`.

        Args:
            array_in_string: Comma-separated words. ``None`` is treated the
                same as an empty string.

        Returns:
            The list of normalized words; ``[]`` for ``None``, empty, or
            blank input.
        """
        # Guard against None as well as blank strings; calling .strip() on
        # None would otherwise raise AttributeError.
        if not array_in_string:
            return []
        words = array_in_string.strip()
        if not words:
            return []
        return [
            self.parse_word(word)
            for word in words.split(',')
            if word.strip() != ''
        ]

    def process_error(
        self,
        err: str
    ) -> str:
        """Wrap a non-empty error message in centered ``<h3>`` HTML markup.

        Falsy values (``""``, ``None``) are returned unchanged, so callers can
        pass the result straight through when there is no error.
        """
        # NOTE(review): *err* is inserted verbatim (not HTML-escaped). That is
        # fine for fixed messages; confirm the explorers' errorChecking never
        # echoes raw user input here.
        if err:
            err = "<center><h3>" + err + "</h3></center>"
        return err
class PhraseBiasExplorerConnector(Connector):
    """Connector that ranks candidate words for a sentence via ``RankSents``."""

    def __init__(
        self,
        **kwargs
    ) -> None:
        """Build the underlying ``RankSents`` explorer.

        Keyword Args:
            language_model: Model object forwarded to ``RankSents``.
            lang: Language identifier forwarded to ``RankSents``.

        Raises:
            KeyError: If ``language_model`` or ``lang`` is missing or ``None``.
        """
        language_model = kwargs.get('language_model', None)
        lang = kwargs.get('lang', None)
        # Name the missing argument(s) instead of raising a bare KeyError();
        # the exception type stays KeyError for existing callers.
        missing = [
            name
            for name, value in (('language_model', language_model), ('lang', lang))
            if value is None
        ]
        if missing:
            raise KeyError(
                "PhraseBiasExplorerConnector missing required argument(s): "
                + ", ".join(missing)
            )
        self.phrase_bias_explorer = RankSents(
            language_model=language_model,
            lang=lang
        )

    def rank_sentence_options(
        self,
        sent: str,
        word_list: str,
        banned_word_list: str,
        useArticles: bool,
        usePrepositions: bool,
        useConjunctions: bool
    ) -> Tuple:
        """Rank candidate fillers for *sent* and return display-ready outputs.

        The sentence is first normalized: every ``*`` is padded with spaces
        and runs of whitespace collapse to single spaces (the ``*`` presumably
        marks the slot to fill -- confirm in ``RankSents``).

        Args:
            sent: Sentence to analyse.
            word_list: Comma-separated candidate words.
            banned_word_list: Comma-separated words to exclude.
            useArticles: Flag forwarded unchanged to ``RankSents.rank``.
            usePrepositions: Flag forwarded unchanged to ``RankSents.rank``.
            useConjunctions: Flag forwarded unchanged to ``RankSents.rank``.

        Returns:
            A 3-tuple ``(error, scores, "")``. If ``errorChecking`` reports a
            problem this is ``(html_wrapped_error, "", "")``; otherwise the
            first element is the falsy value ``errorChecking`` returned and
            ``scores`` is the output of ``Label.compute``. The third element is
            always ``""`` (presumably a placeholder for another output slot).
        """
        sent = " ".join(sent.strip().replace("*", " * ").split())

        err = self.phrase_bias_explorer.errorChecking(sent)
        if err:
            return self.process_error(err), "", ""

        word_list = self.parse_words(word_list)
        banned_word_list = self.parse_words(banned_word_list)

        all_plls_scores = self.phrase_bias_explorer.rank(
            sent,
            word_list,
            banned_word_list,
            useArticles,
            usePrepositions,
            useConjunctions
        )
        all_plls_scores = self.phrase_bias_explorer.Label.compute(all_plls_scores)
        # err is falsy here, so process_error returns it unchanged.
        return self.process_error(err), all_plls_scores, ""
class CrowsPairsExplorerConnector(Connector):
    """Connector that scores a set of sentences via ``CrowsPairs``."""

    def __init__(
        self,
        **kwargs
    ) -> None:
        """Build the underlying ``CrowsPairs`` explorer.

        Keyword Args:
            language_model: Model object forwarded to ``CrowsPairs``.

        Raises:
            KeyError: If ``language_model`` is missing or ``None``.
        """
        language_model = kwargs.get('language_model', None)
        if language_model is None:
            # Name the missing argument instead of raising a bare KeyError();
            # the exception type stays KeyError for existing callers.
            raise KeyError(
                "CrowsPairsExplorerConnector missing required argument(s): "
                "language_model"
            )
        self.crows_pairs_explorer = CrowsPairs(
            language_model=language_model
        )

    def compare_sentences(
        self,
        sent0: str,
        sent1: str,
        sent2: str,
        sent3: str,
        sent4: str,
        sent5: str
    ) -> Tuple:
        """Validate and score six sentences with ``CrowsPairs``.

        The six arguments are passed, in order, as a single list to both
        ``errorChecking`` and ``rank``.

        Returns:
            A 3-tuple ``(error, scores, "")``. If ``errorChecking`` reports a
            problem this is ``(html_wrapped_error, "", "")``; otherwise the
            first element is the falsy value ``errorChecking`` returned and
            ``scores`` is the output of ``Label.compute``. The third element is
            always ``""`` (presumably a placeholder for another output slot).
        """
        sent_list = [sent0, sent1, sent2, sent3, sent4, sent5]

        err = self.crows_pairs_explorer.errorChecking(
            sent_list
        )
        if err:
            return self.process_error(err), "", ""

        all_plls_scores = self.crows_pairs_explorer.rank(
            sent_list
        )
        all_plls_scores = self.crows_pairs_explorer.Label.compute(all_plls_scores)
        # err is falsy here, so process_error returns it unchanged.
        return self.process_error(err), all_plls_scores, ""