Ella2323's picture
Update app.py
e1f49b3
raw
history blame contribute delete
No virus
5.45 kB
from transformers import pipeline
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
from textblob import TextBlob
from hatesonar import Sonar
import gradio as gr
import torch
# --- Model -------------------------------------------------------------------
# Fine-tuned seq2seq reframing model + tokenizer, read from the local
# "output/reframer" directory and wrapped in a summarization pipeline.
model = AutoModelForSeq2SeqLM.from_pretrained("output/reframer")
tokenizer = AutoTokenizer.from_pretrained("output/reframer")
reframer = pipeline(task='summarization', model=model, tokenizer=tokenizer)

# --- Input-validation thresholds ---------------------------------------------
CHAR_LENGTH_LOWER_BOUND = 15  # shortest accepted input text, in characters
CHAR_LENGTH_HIGHER_BOUND = 150  # longest accepted input text, in characters
SENTIMENT_THRESHOLD = 0.2  # highest TextBlob polarity allowed for input text
OFFENSIVENESS_CONFIDENCE_THRESHOLD = 0.8  # offensive-language confidence above which input is rejected

# --- User-facing validation messages -----------------------------------------
LENGTH_ERROR = "The input text is too long or too short. Please try again by inputing text with moderate length."
SENTIMENT_ERROR = "The input text is too positive. Please try again by inputing text with negative sentiment."
OFFENSIVE_ERROR = "The input text is offensive. Please try again by inputing non-offensive text."

# --- Reframing history and decoding settings ---------------------------------
CACHE = []  # recent reframing records, each [input_text, strategy, reframed_text]
MAX_STORE = 5  # how many history records to retain
BEST_N = 3  # how many sampled decodes to request from the model
def input_error_message(error_type):
    # type: (str) -> str
    """Prefix *error_type* with the standard invalid-input marker.

    >>> input_error_message("Too short.")
    '[Error]: Invalid Input. Too short.'
    """
    return "".join(("[Error]: Invalid Input. ", error_type))
def update_cache(cache, new_record, max_store=None):
    # type: (List[List[str]], List[str], Optional[int]) -> List[List[str]]
    """Append *new_record* to *cache* and keep only the newest records.

    Records are ``[input_text, strategy, reframed_text]`` lists, oldest first.

    Args:
        cache: the history list. It is modified in place.
        new_record: the record to append.
        max_store: maximum number of records to keep; defaults to
            ``MAX_STORE``. Values <= 0 leave the cache empty.

    Returns:
        *cache* itself (the same list object), trimmed to at most
        ``max_store`` records.
    """
    limit = MAX_STORE if max_store is None else max_store
    cache.append(new_record)
    # Trim in place so the caller's list and the returned list stay the same
    # object. Dropping *all* excess entries (not just one) also repairs a
    # cache that was already longer than the limit.
    excess = len(cache) - max(limit, 0)
    if excess > 0:
        del cache[:excess]
    return cache
_SONAR = None  # Sonar classifier shared by all reframe() calls; created lazily.


def _get_sonar():
    # type: () -> Sonar
    """Return the shared Sonar instance, constructing it on first use."""
    global _SONAR
    if _SONAR is None:
        _SONAR = Sonar()
    return _SONAR


def reframe(input_text, strategy):
    # type: (str, str) -> str
    """Reframe *input_text* with the given reframing *strategy*.

    The strategy is appended to the input text and the result is passed to
    the fine-tuned ``reframer`` pipeline. Before reframing, the input is
    rejected when it is too short or too long, too positive, or likely
    offensive. In those cases the original text immediately followed by an
    error message is returned instead of a reframing.

    Successful reframings are recorded in the module-level ``CACHE``.

    Args:
        input_text: the sentence to reframe.
        strategy: name of the reframing strategy (e.g. "optimism").

    Returns:
        The reframed text, or ``input_text`` with an appended error message.
    """
    global CACHE

    # Input control.
    # Too short: not enough content to reframe. Too long: the text is less
    # likely to carry a single focused idea.
    if not CHAR_LENGTH_LOWER_BOUND <= len(input_text) <= CHAR_LENGTH_HIGHER_BOUND:
        return input_text + input_error_message(LENGTH_ERROR)
    # Text that is already too positive cannot be meaningfully reframed.
    if TextBlob(input_text).sentiment.polarity > SENTIMENT_THRESHOLD:
        return input_text + input_error_message(SENTIMENT_ERROR)
    # ping() returns a dict whose 'classes' list holds per-class confidences;
    # the second entry is the offensive-language class.
    # NOTE(review): only that one class is checked. Confirm whether the other
    # Sonar classes (e.g. hate speech) should also be filtered.
    offensive_confidence = _get_sonar().ping(input_text)['classes'][1]['confidence']
    if offensive_confidence > OFFENSIVENESS_CONFIDENCE_THRESHOLD:
        return input_text + input_error_message(OFFENSIVE_ERROR)

    # Reframing. The prompt format must match show_n_best_decodes (and,
    # presumably, the format used during fine-tuning).
    text_with_strategy = input_text + "Strategy: ['" + strategy + "']"
    # The pipeline returns a one-element list of dicts; 'summary_text' holds
    # the generated (reframed) text.
    reframed_text = reframer(text_with_strategy)[0]['summary_text']

    CACHE = update_cache(CACHE, [input_text, strategy, reframed_text])
    return reframed_text
def show_reframe_change(input_text, strategy):
    # type: (str, str) -> List[Tuple[str, Optional[str]]]
    """Reframe *input_text* and return a character-level diff against it.

    The input and the reframed text are compared character by character with
    ``difflib.Differ``. Each element of the result is ``(character, tag)``:
    *tag* is the Differ marker (e.g. ``"+"`` or ``"-"``), or ``None`` when the
    character is unchanged. The demo renders these pairs with
    ``gr.HighlightedText``.
    """
    from difflib import Differ

    reframed_text = reframe(input_text, strategy)
    highlighted = []
    for diff_line in Differ().compare(input_text, reframed_text):
        marker, character = diff_line[0], diff_line[2:]
        highlighted.append((character, None if marker == " " else marker))
    return highlighted
def show_n_best_decodes(input_text, strategy):
    # type: (str, str) -> str
    """Generate ``BEST_N`` candidate reframings and return them as numbered lines.

    The prompt uses the same ``"<text>Strategy: ['<strategy>']"`` format as
    ``reframe``. Unlike ``reframe``, no input validation is performed here.

    Decoding uses sampling (``do_sample=True``), so the candidates are random
    samples rather than the highest-scoring beams, and output may differ
    between calls.

    Returns:
        One ``"<rank> <text>"`` line per generated sequence, ranks starting at
        1, joined with newlines (no trailing newline).
    """
    prompt = [input_text + "Strategy: ['" + strategy + "']"]
    # Tokenize straight to PyTorch tensors and pass the attention mask along
    # with the ids instead of rebuilding a tensor from the id lists.
    encoded = tokenizer(prompt, padding=True, return_tensors="pt")
    n_best_decodes = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded.get("attention_mask"),
        do_sample=True,
        num_return_sequences=BEST_N,
    )
    # join() separates entries correctly regardless of how many sequences
    # generate() actually returned (the original keyed this off BEST_N).
    return "\n".join(
        str(rank) + " " + tokenizer.decode(sequence, skip_special_tokens=True)
        for rank, sequence in enumerate(n_best_decodes, start=1)
    )
def show_history(cache):
    # type: (List[List[str]]) -> Any
    """Render the reframing history as a visible Gradio textbox update.

    Each ``[input_text, strategy, reframed_text]`` record in *cache* becomes
    one newline-terminated line of the form
    ``"Input text: ... Strategy: ... -> Reframed text: ..."``.
    """
    lines = [
        "Input text: " + input_text
        + " Strategy: " + strategy
        + " -> Reframed text: " + reframed_text + "\n"
        for input_text, strategy, reframed_text in cache
    ]
    return gr.Textbox.update(value="".join(lines), visible=True)
# Gradio UI: a text box and a strategy selector in, a character-level diff
# (insertions green, deletions red) out.
demo = gr.Interface(
    fn=show_reframe_change,
    inputs=[
        gr.Textbox(
            lines=2,
            placeholder="Please input the sentence to be reframed.",
            label="Original Text",
        ),
        gr.Radio(
            [
                "thankfulness",
                "neutralizing",
                "optimism",
                "growth",
                "impermanence",
                "self_affirmation",
            ],
            label="Strategy to use?",
        ),
    ],
    outputs=gr.HighlightedText(
        label="Diff",
        combine_adjacent=True,
    ).style(color_map={"+": "green", "-": "red"}),
)
demo.launch(show_api=True)