yangxinsci1993
edit interface
6e4fc64
raw
history blame
7.67 kB
from transformers import pipeline
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
from textblob import TextBlob
from hatesonar import Sonar
import gradio as gr
import torch
# Load the fine-tuned reframing model and its tokenizer from the local "output/reframer" directory
model = AutoModelForSeq2SeqLM.from_pretrained("output/reframer")
tokenizer = AutoTokenizer.from_pretrained("output/reframer")
# Wrap both in a 'summarization' pipeline; reframe() reads the 'summary_text' field of its output
reframer = pipeline('summarization', model=model, tokenizer=tokenizer)
# Input-control thresholds applied by reframe() before the model is called
CHAR_LENGTH_LOWER_BOUND = 15 # The minimum character length accepted for the input text
CHAR_LENGTH_HIGHER_BOUND = 150 # The maximum character length accepted for the input text
SENTIMENT_THRESHOLD = 0.2 # The maximum TextBlob sentiment polarity allowed for the input text
OFFENSIVENESS_CONFIDENCE_THRESHOLD = 0.8 # Inputs whose offensive-language confidence exceeds this are rejected
# Reasons appended to the generic prefix built by input_error_message()
LENGTH_ERROR = "The input text is too long or too short. Please try again by inputing text with moderate length."
SENTIMENT_ERROR = "The input text is too positive. Please try again by inputing text with negative sentiment."
OFFENSIVE_ERROR = "The input text is offensive. Please try again by inputing non-offensive text."
CACHE = [] # Most recent reframing records, each stored as [input_text, strategy, reframed_text]
MAX_STORE = 5 # The maximum number of records update_cache() keeps in the cache
BEST_N = 3 # The number of decoded candidates show_n_best_decodes() returns
def input_error_message(error_type):
    # type: (str) -> str
    """Build the user-facing message for a rejected input.

    `error_type` is the human-readable reason for the rejection; it is
    appended verbatim to a fixed "[Error]: Invalid Input. " prefix.
    """
    prefix = "[Error]: Invalid Input. "
    return prefix + error_type
def update_cache(cache, new_record, max_store=None):
    # type: (List[List[str]], List[str], Optional[int]) -> List[List[str]]
    """Append a record to the history cache and drop the oldest overflow.

    Args:
        cache: The history list; it is modified in place.
        new_record: The entry to append (the app stores
            [input_text, strategy, reframed_text]).
        max_store: Maximum number of records to keep. Defaults to the
            module-level MAX_STORE when omitted.

    Returns:
        The same list object that was passed in, holding at most
        `max_store` of the most recent records.
    """
    limit = MAX_STORE if max_store is None else max_store
    cache.append(new_record)
    # Trim in place so the caller's list and the returned list never diverge,
    # and remove every surplus entry rather than just one.
    excess = len(cache) - limit
    if excess > 0:
        del cache[:excess]
    return cache
_SONAR = None  # Shared hatesonar classifier, created on first use by _get_sonar()


def _get_sonar():
    """Return the shared Sonar classifier, constructing it only once."""
    global _SONAR
    if _SONAR is None:
        _SONAR = Sonar()
    return _SONAR


def reframe(input_text, strategy):
    # type: (str, str) -> str
    """Reframe the input text with a specified strategy.

    The input is validated first (strategy chosen, character length,
    sentiment, offensiveness). If a check fails, the corresponding error
    message from input_error_message() is returned instead of a reframing.
    Otherwise the strategy is concatenated to the input text, passed to the
    fine-tuned BART model through the `reframer` pipeline, the result is
    recorded in CACHE, and the reframed positive text is returned.
    """
    global CACHE

    # Input Control
    # No radio option selected yields strategy=None, which would otherwise
    # crash the string concatenation below.
    if not strategy:
        return input_error_message("No strategy is selected. Please choose a strategy and try again.")
    # The input text cannot be too short to ensure it has substantial content to be reframed.
    # It also cannot be too long to ensure the text has a focused idea.
    if not CHAR_LENGTH_LOWER_BOUND <= len(input_text) <= CHAR_LENGTH_HIGHER_BOUND:
        return input_error_message(LENGTH_ERROR)
    # The input text cannot be too positive to ensure the text can be positively reframed.
    if TextBlob(input_text).sentiment.polarity > SENTIMENT_THRESHOLD:
        return input_error_message(SENTIMENT_ERROR)
    # The input text cannot be offensive.
    # ping() outputs a dictionary; the second entry under the key 'classes' holds the
    # confidence for the input text being offensive language.
    offensive_confidence = _get_sonar().ping(input_text)['classes'][1]['confidence']
    if offensive_confidence > OFFENSIVENESS_CONFIDENCE_THRESHOLD:
        return input_error_message(OFFENSIVE_ERROR)

    # Reframing
    text_with_strategy = input_text + "Strategy: ['" + strategy + "']"
    # The reframer pipeline outputs a list containing one dictionary whose
    # 'summary_text' value is the reframed text.
    reframed_text = reframer(text_with_strategy)[0]['summary_text']

    # Update cache
    CACHE = update_cache(CACHE, [input_text, strategy, reframed_text])
    return reframed_text
def show_reframe_change(input_text, strategy):
    # type: (str, str) -> List[Tuple[str, Optional[str]]]
    """Reframe the input and describe the character-level edits between the two texts.

    The input text and the text returned by reframe() are compared character
    by character. The result is a list of two-element tuples: the first
    element is the character, the second is "+" when the character was added
    in the reframed text, "-" when it was removed from the input text, and
    None when it is unchanged.
    """
    from difflib import Differ

    reframed_text = reframe(input_text, strategy)
    highlighted = []
    for entry in Differ().compare(input_text, reframed_text):
        # Each entry is a two-character code followed by the character itself.
        marker = entry[0]
        character = entry[2:]
        highlighted.append((character, None if marker == " " else marker))
    return highlighted
def show_n_best_decodes(input_text, strategy):
    # type: (str, str) -> str
    """Sample BEST_N reframings of the input text and return them as a numbered list.

    The prompt is built the same way as in reframe() (input text followed by
    the strategy tag), tokenized, and passed to model.generate() with
    sampling enabled. Each decoded sequence is placed on its own line,
    prefixed with its 1-based position. If no strategy was chosen, an input
    error message is returned instead.
    """
    # No radio option selected yields strategy=None, which would otherwise
    # crash the string concatenation below.
    if not strategy:
        return input_error_message("No strategy is selected. Please choose a strategy and try again.")

    prompt = [input_text + "Strategy: ['" + strategy + "']"]
    input_ids = torch.tensor(tokenizer(prompt, padding=True)['input_ids'])
    n_best_decodes = model.generate(
        input_ids,
        do_sample=True,
        num_return_sequences=BEST_N,
    )
    # Join on newlines so the separators follow the number of sequences
    # actually returned rather than the BEST_N constant.
    return "\n".join(
        str(rank) + " " + tokenizer.decode(sequence, skip_special_tokens=True)
        for rank, sequence in enumerate(n_best_decodes, start=1)
    )
def show_history(cache):
    # type: (List[List[str]]) -> Any
    """Render the reframing history as text for a Gradio textbox.

    Args:
        cache: Iterable of [input_text, strategy, reframed_text] records.

    Returns:
        The result of gr.Textbox.update() carrying one line per record and
        making the textbox visible (not a plain string).
    """
    history = "".join(
        "Input text: " + input_text + " Strategy: " + strategy + " -> Reframed text: " + reframed_text + "\n"
        for input_text, strategy, reframed_text in cache
    )
    return gr.Textbox.update(value=history, visible=True)
# Build Gradio interface
with gr.Blocks() as demo:
    # Instruction
    gr.Markdown(
        '''
# Positive Reframing
Start inputing negative texts to see how you can see the same event from a positive angle.
''')

    # Input text to be reframed
    text = gr.Textbox(label="Original Text")

    # Input strategy for the reframing
    gr.Markdown(
        '''
Choose one of the six strategies to carry out reframing: \n
**Growth Mindset:** Viewing a challenging event as an opportunity for the author specifically to grow or improve themselves. \n
**Impermanence:** Saying bad things don’t last forever, will get better soon, and/or that others have experienced similar struggles. \n
**Neutralizing:** Replacing a negative word with a neutral word. For example, “This was a terrible day” becomes “This was a long day.” \n
**Optimism:** Focusing on things about the situation itself, in that moment, that are good (not just forecasting a better future). \n
**Self-affirmation:** Talking about what strengths the author already has, or the values they admire, like love, courage, perseverance, etc. \n
**Thankfulness:** Expressing thankfulness or gratitude with key words like appreciate, glad that, thankful for, good thing, etc.
''')
    strategy = gr.Radio(
        ["thankfulness", "neutralizing", "optimism", "growth", "impermanence", "self_affirmation"],
        label="Strategy to use?",
    )

    # Trigger button for reframing; the result is shown as a character-level diff
    reframe_btn = gr.Button("Reframe")
    best_output = gr.HighlightedText(
        label="Diff",
        combine_adjacent=True,
    ).style(color_map={"+": "green", "-": "red"})
    reframe_btn.click(fn=show_reframe_change, inputs=[text, strategy], outputs=best_output)

    # Trigger button for showing n best reframings
    n_best_btn = gr.Button("Show Best {n} Results".format(n=BEST_N))
    n_best_output = gr.Textbox(interactive=False)
    n_best_btn.click(fn=show_n_best_decodes, inputs=[text, strategy], outputs=n_best_output)

    # Default examples of text and strategy pairs for user to have a quick start
    gr.Markdown("## Examples")
    gr.Examples(
        examples=[
            ["I have a lot of homework to do today.", "self_affirmation"],
            ["This has been the longest and most stressful week of my life!", "optimism"],
            ["So stressed about the midterms next week.", "thankfulness"],
        ],
        inputs=[text, strategy],
        # Must name the diff component defined above; the previous bare
        # `output` was undefined and raised NameError while building the UI.
        outputs=best_output,
        fn=show_reframe_change,
        cache_examples=False,
        run_on_click=False,
    )

    # Link to paper and Github repo
    gr.Markdown(
        '''
For more details: You can read our [paper](https://arxiv.org/abs/2204.02952) or access our [code](https://github.com/SALT-NLP/positive-frames).
''')

demo.launch()