File size: 5,452 Bytes
7b6f763
 
 
0fc771a
6e4fc64
 
 
7b6f763
 
7df388d
 
7b6f763
 
 
6e4fc64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8b5b53
6e4fc64
 
f8b5b53
6e4fc64
eafc169
6e4fc64
 
f8b5b53
6e4fc64
 
 
 
7b6f763
6e4fc64
 
 
7b6f763
6e4fc64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1f49b3
 
 
 
 
27c0aa2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from transformers import pipeline
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
from textblob import TextBlob
from hatesonar import Sonar
import gradio as gr
import torch

# Load trained model
# The fine-tuned seq2seq (BART-style) reframing model and its tokenizer are read
# from the local "output/reframer" directory produced by training.
model = AutoModelForSeq2SeqLM.from_pretrained("output/reframer")
tokenizer = AutoTokenizer.from_pretrained("output/reframer")
# Wrapped in a 'summarization' pipeline: input text in, reframed text out
# under the 'summary_text' key.
reframer = pipeline('summarization', model=model, tokenizer=tokenizer)


# Input-validation bounds (see reframe() for how each is applied).
CHAR_LENGTH_LOWER_BOUND = 15 # The minimum character length threshold for the input text
CHAR_LENGTH_HIGHER_BOUND = 150 # The maximum character length threshold for the input text
SENTIMENT_THRESHOLD = 0.2 # The maximum Textblob sentiment score for the input text
OFFENSIVENESS_CONFIDENCE_THRESHOLD = 0.8 # The threshold for the confidence score of a text being offensive

# User-facing error strings appended to the echoed input on rejection.
LENGTH_ERROR = "The input text is too long or too short. Please try again by inputing text with moderate length."
SENTIMENT_ERROR = "The input text is too positive. Please try again by inputing text with negative sentiment."
OFFENSIVE_ERROR = "The input text is offensive. Please try again by inputing non-offensive text."

CACHE = [] # A list storing the most recent 5 reframing history
MAX_STORE = 5 # The maximum number of history user would like to store

BEST_N = 3 # The number of best decodes user would like to see


def input_error_message(error_type):
    # type: (str) -> str
    """Build the user-facing message for a rejected input from its error text."""
    return f"[Error]: Invalid Input. {error_type}"

def update_cache(cache, new_record, max_store=None):
    # type: (list, list, int | None) -> list
    """Append a reframing record to the history cache, keeping only the newest entries.

    Args:
        cache: List of [input_text, strategy, reframed_text] records (mutated in place).
        new_record: The record to append.
        max_store: Maximum number of records to retain; defaults to the
            module-level MAX_STORE when None (backward compatible).

    Returns:
        The trimmed cache. Callers should rebind to the return value, since
        trimming produces a new list.
    """
    limit = MAX_STORE if max_store is None else max_store
    cache.append(new_record)
    if len(cache) > limit:
        # Keep the most recent `limit` records; a tail slice stays correct
        # even if the cache ever overflows by more than one entry.
        cache = cache[-limit:]
    return cache

def reframe(input_text, strategy):
    # type: (str, str) -> str
    """Reframe the input text with a specified strategy.

    The strategy is concatenated to the input text and passed to the
    finetuned BART reframer pipeline. Inputs that are too short/long, too
    positive, or offensive are rejected: the input text is echoed back with
    an error message appended instead of being reframed.

    Args:
        input_text: The negative sentence to reframe.
        strategy: One of the supported reframing strategies (e.g. "optimism").

    Returns:
        The reframed positive text, or the echoed input plus an error message.
    """
    text_with_strategy = input_text + "Strategy: ['" + strategy + "']"

    # Input Control
    # The input must be long enough to have substantial content to reframe,
    # yet short enough to carry a single focused idea.
    if len(input_text) < CHAR_LENGTH_LOWER_BOUND or len(input_text) > CHAR_LENGTH_HIGHER_BOUND:
        return input_text + input_error_message(LENGTH_ERROR)
    # The input cannot already be positive, or there is nothing to reframe.
    # Fix: use the declared SENTIMENT_THRESHOLD constant instead of a
    # hard-coded 0.2 so the threshold stays consistent with the config above.
    if TextBlob(input_text).sentiment.polarity > SENTIMENT_THRESHOLD:
        return input_text + input_error_message(SENTIMENT_ERROR)
    # The input text cannot be offensive.
    # NOTE(review): Sonar() is constructed on every call, which reloads the
    # classifier each time — consider hoisting to module level if latency matters.
    sonar = Sonar()
    # sonar.ping(...) returns a dict; classes[1] is the confidence that the
    # text is offensive language.
    if sonar.ping(input_text)['classes'][1]['confidence'] > OFFENSIVENESS_CONFIDENCE_THRESHOLD:
        return input_text + input_error_message(OFFENSIVE_ERROR)

    # Reframing: the pipeline returns a one-element list whose 'summary_text'
    # is the reframed output.
    reframed_text = reframer(text_with_strategy)[0]['summary_text']

    # Update cache with this successful reframing.
    global CACHE
    CACHE = update_cache(CACHE, [input_text, strategy, reframed_text])

    return reframed_text


def show_reframe_change(input_text, strategy):
    # type: (str, str) -> List[Tuple[str, str]]
    """Return a character-level diff from the input text to its reframed form.

    Each tuple pairs a character with its diff marker: '+' for an addition,
    '-' for a deletion, or None when the character is unchanged — the format
    Gradio's HighlightedText component consumes.
    """
    from difflib import Differ

    reframed_text = reframe(input_text, strategy)
    differ = Differ()
    diff = []
    for token in differ.compare(input_text, reframed_text):
        marker = token[0]
        diff.append((token[2:], None if marker == " " else marker))
    return diff

def show_n_best_decodes(input_text, strategy):
    # type: (str, str) -> str
    """Sample BEST_N decodings for the prompt and return them numbered, one per line.

    Args:
        input_text: The negative sentence to reframe.
        strategy: The reframing strategy appended to the prompt.

    Returns:
        A string of up to BEST_N lines, each "<rank> <decoded text>".
    """
    prompt = [input_text + "Strategy: ['" + strategy + "']"]
    input_ids = torch.tensor(tokenizer(prompt, padding=True)['input_ids'])
    n_best_decodes = model.generate(input_ids,
                                    do_sample=True,
                                    num_return_sequences=BEST_N)
    # Fix: join the decoded lines instead of appending "\n" while comparing
    # the loop index against BEST_N — the original left a trailing newline
    # (or dropped one) whenever generate() returned a different count.
    return "\n".join(
        f"{rank + 1} {tokenizer.decode(sequence, skip_special_tokens=True)}"
        for rank, sequence in enumerate(n_best_decodes)
    )

def show_history(cache):
    # type: (list) -> object
    """Format the cached reframing records and reveal them in the history textbox.

    Args:
        cache: List of [input_text, strategy, reframed_text] records.

    Returns:
        A Gradio textbox update carrying the formatted history, made visible.
    """
    # Build with join rather than repeated += (avoids quadratic concatenation).
    history = "".join(
        "Input text: " + input_text + " Strategy: " + strategy
        + " -> Reframed text: " + reframed_text + "\n"
        for input_text, strategy, reframed_text in cache
    )
    # NOTE(review): gr.Textbox.update was removed in Gradio 4 (use gr.update /
    # gr.Textbox(...)) — confirm the pinned gradio version before upgrading.
    return gr.Textbox.update(value=history, visible=True)



# Gradio UI: a two-line textbox for the negative sentence plus a radio selector
# for the reframing strategy; the output is a HighlightedText showing the
# character-level diff from show_reframe_change ('+' additions in green,
# '-' deletions in red).
# NOTE(review): component-level .style() is deprecated in newer Gradio
# releases — confirm the pinned gradio version supports it.
demo = gr.Interface(
    fn=show_reframe_change,
    inputs=[gr.Textbox(lines=2, placeholder="Please input the sentence to be reframed.", label="Original Text"), gr.Radio(["thankfulness", "neutralizing", "optimism", "growth", "impermanence", "self_affirmation"], label="Strategy to use?")],
    outputs=gr.HighlightedText(label="Diff",combine_adjacent=True,).style(color_map={"+": "green", "-": "red"}),
)
demo.launch(show_api=True)