from transformers import pipeline
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
from textblob import TextBlob
from hatesonar import Sonar
import gradio as gr
import torch
# Load the fine-tuned seq2seq reframing model and its tokenizer from the local
# "output/reframer" checkpoint directory. This runs once, at import time.
model = AutoModelForSeq2SeqLM.from_pretrained("output/reframer")
tokenizer = AutoTokenizer.from_pretrained("output/reframer")
# Wrap model + tokenizer in a 'summarization' pipeline; reframe() reads the
# generated text from the 'summary_text' key of its first result.
reframer = pipeline('summarization', model=model, tokenizer=tokenizer)
# --- Input-validation limits ------------------------------------------------
CHAR_LENGTH_LOWER_BOUND = 15  # Shortest accepted input, in characters.
CHAR_LENGTH_HIGHER_BOUND = 150  # Longest accepted input, in characters.
SENTIMENT_THRESHOLD = 0.2  # Highest accepted TextBlob sentiment polarity for the input.
OFFENSIVENESS_CONFIDENCE_THRESHOLD = 0.8  # Offensive-language confidence above which input is rejected.

# --- User-facing rejection reasons (prefixed by input_error_message) --------
LENGTH_ERROR = (
    "The input text is too long or too short. "
    "Please try again by inputing text with moderate length."
)
SENTIMENT_ERROR = (
    "The input text is too positive. "
    "Please try again by inputing text with negative sentiment."
)
OFFENSIVE_ERROR = (
    "The input text is offensive. "
    "Please try again by inputing non-offensive text."
)

# --- Reframing history and decoding settings ---------------------------------
MAX_STORE = 5  # Maximum number of reframing records to keep in the history.
CACHE = []  # Recent reframing history records, maintained via update_cache().
BEST_N = 3  # Number of best decodes to show the user.
def input_error_message(error_type):
    # type: (str) -> str
    """Build the user-facing message for a rejected input.

    Args:
        error_type: Human-readable reason the input was rejected.

    Returns:
        ``error_type`` prefixed with "[Error]: Invalid Input. ".
    """
    prefix = "[Error]: Invalid Input. "
    return "".join((prefix, error_type))
def update_cache(cache, new_record, max_store=None):
    # type: (List[List[str]], List[str], Optional[int]) -> List[List[str]]
    """Add ``new_record`` to the history, keeping only the most recent entries.

    Args:
        cache: Existing history records, oldest first. It is not modified.
        new_record: Record to add, e.g. ``[input_text, strategy, reframed_text]``.
        max_store: Maximum number of records to keep. When None, the
            module-level ``MAX_STORE`` is used.

    Returns:
        A new list with ``new_record`` appended last and at most ``max_store``
        records, dropping the oldest ones first. A non-positive limit
        yields an empty list.
    """
    limit = MAX_STORE if max_store is None else max_store
    if limit <= 0:
        # Guard explicitly: a slice like [-0:] would keep everything.
        return []
    # Append first, then keep only the newest `limit` records.
    return (list(cache) + [new_record])[-limit:]
_SONAR = None  # Shared Sonar classifier, created lazily by _get_sonar().


def _get_sonar():
    # type: () -> Sonar
    """Return a shared Sonar instance, constructing it on first use.

    The classifier is built once and reused across requests instead of
    being re-created on every call to ``reframe``.
    """
    global _SONAR
    if _SONAR is None:
        _SONAR = Sonar()
    return _SONAR


def _validate_reframe_input(input_text):
    # type: (str) -> Optional[str]
    """Return the reason ``input_text`` is rejected, or None if it is acceptable.

    The checks run in order: character length, TextBlob sentiment polarity,
    then the Sonar offensiveness score.
    """
    # The text must be long enough to have substantial content to reframe,
    # and short enough to express a single focused idea.
    if not CHAR_LENGTH_LOWER_BOUND <= len(input_text) <= CHAR_LENGTH_HIGHER_BOUND:
        return LENGTH_ERROR
    # Text that is already too positive cannot meaningfully be reframed positively.
    if TextBlob(input_text).sentiment.polarity > SENTIMENT_THRESHOLD:
        return SENTIMENT_ERROR
    # sonar.ping() returns a dict whose 'classes' list holds per-class scores;
    # the entry at index 1 is the confidence that the text is offensive language.
    classes = _get_sonar().ping(input_text)['classes']
    if classes[1]['confidence'] > OFFENSIVENESS_CONFIDENCE_THRESHOLD:
        return OFFENSIVE_ERROR
    return None


def reframe(input_text, strategy):
    # type: (str, str) -> str
    """Reframe ``input_text`` positively using the given ``strategy``.

    The strategy is appended to the input as ``Strategy: ['<strategy>']`` and
    the combined prompt is passed to the fine-tuned BART ``reframer`` pipeline.

    Args:
        input_text: Text to reframe.
        strategy: Name of the reframing strategy (e.g. "optimism"). May be
            None when no strategy was chosen.

    Returns:
        The reframed text. If the input is rejected (or no strategy was
        given), returns ``input_text`` with an "[Error]: Invalid Input. ..."
        message appended instead.

    Side effects:
        On success, stores ``update_cache(CACHE, [input_text, strategy,
        reframed_text])`` back into the global ``CACHE``.
    """
    global CACHE
    # Guard before building the prompt: concatenating None would raise
    # TypeError (e.g. when no strategy option was selected in the UI).
    if strategy is None:
        return input_text + input_error_message("Please select a reframing strategy.")
    error = _validate_reframe_input(input_text)
    if error is not None:
        return input_text + input_error_message(error)
    # NOTE: no separator before "Strategy:"; presumably this matches the
    # fine-tuning data format, so keep it unchanged.
    text_with_strategy = input_text + "Strategy: ['" + strategy + "']"
    # The pipeline returns a list with one dict; 'summary_text' holds the output.
    reframed_text = reframer(text_with_strategy)[0]['summary_text']
    CACHE = update_cache(CACHE, [input_text, strategy, reframed_text])
    return reframed_text
def show_reframe_change(input_text, strategy):
    # type: (str, str) -> List[Tuple[str, Optional[str]]]
    """Reframe ``input_text`` and return a character-level diff against it.

    Args:
        input_text: Text to reframe.
        strategy: Reframing strategy passed through to ``reframe``.

    Returns:
        A list of ``(character, action)`` tuples for ``gr.HighlightedText``.
        ``action`` is "+" for a character added in the reframed text, "-"
        for a character removed from the input, and None for an unchanged
        character.
    """
    from difflib import Differ

    reframed_text = reframe(input_text, strategy)
    diff = []
    # Comparing two strings makes Differ treat each character as one "line".
    for token in Differ().compare(input_text, reframed_text):
        marker, char = token[0], token[2:]
        # "?" lines are intraline hints, not characters of either text.
        if marker == "?":
            continue
        diff.append((char, None if marker == " " else marker))
    return diff
def show_n_best_decodes(input_text, strategy):
    # type: (str, str) -> str
    """Return the top ``BEST_N`` beam-search reframings, one per line.

    The prompt uses the same ``Strategy: ['<strategy>']`` suffix as
    ``reframe``. Unlike ``reframe``, no input validation is applied and the
    history cache is not updated.

    Returns:
        Lines of the form "<rank> <decoded text>", with ranks starting at 1,
        joined by newlines (no trailing newline).
    """
    prompt = [input_text + "Strategy: ['" + strategy + "']"]
    encoded = tokenizer(prompt, padding=True, return_tensors="pt")
    # NOTE(review): generation settings reconstructed; num_beams and
    # num_return_sequences = BEST_N are the minimum needed to return BEST_N
    # candidates. Confirm against the intended decoding settings.
    n_best_decodes = model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        num_beams=BEST_N,
        num_return_sequences=BEST_N,
    )
    decoded = tokenizer.batch_decode(n_best_decodes, skip_special_tokens=True)
    # Join on the actual number of decodes rather than assuming exactly BEST_N.
    return "\n".join(
        str(rank) + " " + text for rank, text in enumerate(decoded, start=1)
    )
def show_history(cache):
    # type: (List[List[str]]) -> Any
    """Render the reframing history and make its Gradio textbox visible.

    Args:
        cache: History records, each ``[input_text, strategy, reframed_text]``.

    Returns:
        The result of ``gr.Textbox.update`` with one newline-terminated line
        per record as the value and ``visible=True``.
    """
    history = "".join(
        "Input text: " + input_text
        + " Strategy: " + strategy
        + " -> Reframed text: " + reframed_text + "\n"
        for input_text, strategy, reframed_text in cache
    )
    return gr.Textbox.update(value=history, visible=True)
demo = gr.Interface(
inputs=[gr.Textbox(lines=2, placeholder="Please input the sentence to be reframed.", label="Original Text"), gr.Radio(["thankfulness", "neutralizing", "optimism", "growth", "impermanence", "self_affirmation"], label="Strategy to use?")],
outputs=gr.HighlightedText(label="Diff",combine_adjacent=True,).style(color_map={"+": "green", "-": "red"}),