|
import spaces |
|
import gradio as gr |
|
import torch |
|
from peft import PeftModel, PeftConfig |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
import json |
|
from datetime import datetime |
|
from uuid import uuid4 |
|
import os |
|
from pathlib import Path |
|
from huggingface_hub import CommitScheduler |
|
from utils import hide_code, hide_css |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_data_to_format(text):
    """Wrap raw input text in the prompt template the LoRA adapters were trained on."""
    return "### Original: " + text + "\n ### Rewrite:"
|
|
|
# LoRA adapter repos, one per steerable style attribute/direction.
# Key naming: "<attribute>_<direction>"; greet() picks the direction by slider sign.
MODEL_PATHS = {
    "length_more": "hallisky/lora-length-long-llama-3-8b",
    "length_less": "hallisky/lora-length-short-llama-3-8b",
    "function_more": "hallisky/lora-function-more-llama-3-8b",
    "function_less": "hallisky/lora-function-less-llama-3-8b",
    "grade_more": "hallisky/lora-grade-highschool-llama-3-8b",
    "grade_less": "hallisky/lora-grade-elementary-llama-3-8b",
    "formality_more": "hallisky/lora-formality-formal-llama-3-8b",
    "formality_less": "hallisky/lora-formality-informal-llama-3-8b",
    "sarcasm_more": "hallisky/lora-sarcasm-more-llama-3-8b",
    "sarcasm_less": "hallisky/lora-sarcasm-less-llama-3-8b",
    "voice_passive": "hallisky/lora-voice-passive-llama-3-8b",
    "voice_active": "hallisky/lora-voice-active-llama-3-8b",
    "type_persuasive": "hallisky/lora-type-persuasive-llama-3-8b",
    "type_expository": "hallisky/lora-type-expository-llama-3-8b",
    "type_narrative": "hallisky/lora-type-narrative-llama-3-8b",
    "type_descriptive": "hallisky/lora-type-descriptive-llama-3-8b",
}

# Adapter used when first wrapping the base model in a PeftModel.
FIRST_MODEL = next(iter(MODEL_PATHS))

# Cap on generated tokens per rewrite.
MAX_NEW_TOKENS = 1024
|
|
|
# User-facing intro rendered as Markdown at the top of the demo.
# Fix: the Inference Endpoints sentence duplicated the phrase "deploy the
# model on" both before and inside the link text.
# TODO(review): the paper/code links point at placeholder "google.com" —
# replace with the real URLs before release.
DESCRIPTION = """\
# Authorship Obfuscation with StyleRemix

This Space demonstrates StyleRemix, a controllable and interpretable method for authorship obfuscation. At its core, it uses a Llama-3 model with 8B parameters and various LoRA adapters fine-tuned to rewrite text towards specific stylistic attributes (like text being longer or shorter). Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
<br> 🕵️ Want to learn more? Check out our paper [here](google.com) and our code [here](google.com)!
<br> 🧐 Have questions about our work or issues with the demo? Feel free to email us at jrfish@uw.edu and hallisky@uw.edu.
"""
|
|
|
import subprocess

def print_nvidia_smi():
    """Print GPU status via `nvidia-smi`; log a message instead of raising on failure."""
    try:
        proc = subprocess.run(['nvidia-smi'], capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        # nvidia-smi exists but exited non-zero (e.g. driver problem).
        print(f"Failed to run nvidia-smi: {e}")
    except FileNotFoundError:
        # Binary missing entirely (CPU-only machine or stripped container).
        print("nvidia-smi is not installed or not in the PATH.")
    else:
        print(proc.stdout)
|
|
|
|
|
# Device selection. The demo only actually works on GPU; on CPU we just
# surface a warning in the UI description.
# NOTE(review): on CPU, `model`/`tokenizer` are never defined, so greet()
# would raise NameError if invoked — confirm this is acceptable for Spaces.
if not torch.cuda.is_available():
    device = "cpu"
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

if torch.cuda.is_available():
    device = "cuda"
    model_id = "meta-llama/Meta-Llama-3-8B"

    # Left padding so decoding starts immediately after the prompt; BOS added,
    # EOS withheld so the model continues the "### Rewrite:" template.
    tokenizer = AutoTokenizer.from_pretrained(model_id, add_bos_token=True, add_eos_token=False, padding_side="left")
    tokenizer.add_special_tokens({'pad_token': '<padding_token>'})

    base_model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
    # The embedding matrix must grow to cover the pad token added above.
    base_model.resize_token_embeddings(len(tokenizer))

    # Wrap the base model with the first LoRA adapter, then register the rest
    # so greet() can merge arbitrary subsets with add_weighted_adapter().
    model = PeftModel.from_pretrained(base_model, MODEL_PATHS[FIRST_MODEL], adapter_name=FIRST_MODEL).to(device)

    for cur_adapter in MODEL_PATHS.keys():
        if cur_adapter != FIRST_MODEL:
            model.load_adapter(MODEL_PATHS[cur_adapter], adapter_name=cur_adapter)

    model.to(device)
    model.eval()  # inference only — disable dropout etc.
|
|
|
|
|
# Random per-process session id; shards saved records into per-user files.
user_id = str(uuid4())

# Local staging directory for the JSON-lines interaction logs.
JSON_DATASET_DIR = Path("json_dataset")
JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)

JSON_DATASET_PATH = JSON_DATASET_DIR / f"train-{user_id}.json"

# Background scheduler that periodically commits the staging folder to a
# Hugging Face dataset repo. `every` is in minutes per the huggingface_hub
# docs, so 0.5 == every 30 seconds.
# NOTE(review): repo_id has no namespace ("user/repo") — verify it resolves
# to the intended dataset for the authenticated account.
scheduler = CommitScheduler(
    repo_id="authorship-obfuscation-demo-data",
    repo_type="dataset",
    folder_path=JSON_DATASET_DIR,
    path_in_repo="data",
    every=0.5
)
|
|
|
def save_data(data):
    """Append one record to the session's JSON-lines file.

    Held under the scheduler lock so a background commit never uploads a
    half-written line.
    """
    with scheduler.lock, JSON_DATASET_PATH.open("a") as f:
        json.dump(data, f)
        f.write("\n")
|
|
|
def save_feedback(feedback_rating, feedback_text, latest_obfuscation):
    """Attach user feedback to the latest obfuscation record and persist it.

    Returns reset values for the rating radio and feedback textbox, plus an
    update that shows the confirmation message.
    """
    latest_obfuscation.update(
        feedback_rating=feedback_rating,
        feedback_text=feedback_text,
    )
    save_data(latest_obfuscation)
    return "No Feedback Selected", "", gr.update(visible=True)
|
|
|
@spaces.GPU
def greet(input_text, length, function_words, grade_level, formality, sarcasm, voice, persuasive, descriptive, narrative, expository):
    """Rewrite `input_text` steered by the style sliders.

    Each slider's sign selects an adapter direction ("more"/"less",
    active/passive); its magnitude becomes the adapter weight. Writing-type
    sliders are one-sided (any non-zero value selects the adapter). Active
    adapters are merged into a temporary weighted adapter, the model generates
    a rewrite, and the whole interaction is logged via save_data().

    Returns:
        (response_text, update enabling the feedback radio, update enabling
        the feedback textbox, the logged obfuscation record for gr.State).
    """
    # NOTE(review): latest_obfuscation is also a gr.Blocks State; this global
    # shadows the module-level name of that component — confirm intended.
    global latest_obfuscation, user_id
    current_time = datetime.now().isoformat()

    # Translate each slider into (adapter_name or None, weight = |value|).
    sliders_dict = {}
    cur_keys = []
    cur_keys.append(("length_more" if length > 0 else (None if length == 0 else "length_less"), abs(length)))
    cur_keys.append(("function_more" if function_words > 0 else (None if function_words == 0 else "function_less"), abs(function_words)))
    cur_keys.append(("grade_more" if grade_level > 0 else (None if grade_level == 0 else "grade_less"), abs(grade_level)))
    cur_keys.append(("sarcasm_more" if sarcasm > 0 else (None if sarcasm == 0 else "sarcasm_less"), abs(sarcasm)))
    cur_keys.append(("formality_more" if formality > 0 else (None if formality == 0 else "formality_less"), abs(formality)))
    cur_keys.append(("voice_active" if voice > 0 else (None if voice == 0 else "voice_passive"),abs(voice)))
    # Writing-type sliders have no "less" direction: non-zero means selected.
    cur_keys.append(("type_persuasive" if persuasive != 0 else None, abs(persuasive)))
    cur_keys.append(("type_descriptive" if descriptive != 0 else None, abs(descriptive)))
    cur_keys.append(("type_narrative" if narrative != 0 else None, abs(narrative)))
    cur_keys.append(("type_expository" if expository != 0 else None, abs(expository)))

    # Keep only the adapters whose slider is non-zero.
    for cur_key in cur_keys:
        if cur_key[0] is not None:
            sliders_dict[cur_key[0]] = cur_key[1]

    print(sliders_dict)  # debug logging

    if len(sliders_dict) > 0:
        # Build a name for the merged adapter from its components, e.g.
        # "length_more50-formality_less100" (weights scaled x100, dash-joined).
        combo_adapter_name = ""
        for slider_key in sliders_dict:
            print(slider_key)
            print(sliders_dict[slider_key])
            combo_adapter_name += slider_key + str(int(100*sliders_dict[slider_key])) + "-"
        combo_adapter_name = combo_adapter_name[:-1]  # drop trailing "-"
        print(combo_adapter_name)
        print(list(sliders_dict.values()))
        print(list(sliders_dict.keys()))
        print(list(model.peft_config.keys()))

        # Merge the selected adapters into one weighted adapter and activate it.
        # NOTE(review): PEFT may raise if an adapter with this name already
        # exists (same sliders submitted twice in one process) — confirm.
        model.add_weighted_adapter(
            list(sliders_dict.keys()),
            weights = list(sliders_dict.values()),
            adapter_name = combo_adapter_name,
            combination_type = "cat"
        )
        model.set_adapter(combo_adapter_name)

        # Prompt with the "### Original ... ### Rewrite:" template and decode
        # only the newly generated tokens as the rewrite.
        converted_text = convert_data_to_format(input_text)
        inputs = tokenizer(converted_text, return_tensors="pt", max_length=2048, truncation=True).to(device)
        input_length = inputs.input_ids.shape[1]
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, top_p = 0.95)
        response = tokenizer.decode(outputs[0, input_length:], skip_special_tokens=True).strip()
        full_output = tokenizer.decode(outputs[0], skip_special_tokens=False)
    else:
        # No active sliders: echo the input unchanged.
        response = input_text
        full_output = response

    # Full interaction record, later enriched with feedback and committed to
    # the dataset repo by the CommitScheduler.
    latest_obfuscation = {
        "datetime": current_time,
        "user_id": user_id,
        "input_text": input_text,
        "sliders": {
            "length": length,
            "function_words": function_words,
            "grade_level": grade_level,
            "sarcasm": sarcasm,
            "formality": formality,
            "voice": voice,
            "persuasive": persuasive,
            "descriptive": descriptive,
            "narrative": narrative,
            "expository": expository
        },
        "input": input_text,  # NOTE(review): duplicates "input_text" above — confirm both keys are consumed downstream
        "output": response,
        "full_output": full_output,
        "feedback_rating": "No Feedback Selected",
        "feedback_text": ""
    }

    save_data(latest_obfuscation)

    return response, gr.update(interactive=True), gr.update(interactive=True), latest_obfuscation
|
|
|
def auto_sliders():
    """Return preset values for all ten sliders: 0.5 for the six style-element
    sliders, 0 for the four writing-type sliders.

    Fix: the UI builds 6 style sliders followed by 4 writing-type sliders,
    but the original `[0.5] * 7 + [0] * 3` spilled a 0.5 into the first
    (non-interactive) writing-type slider.
    """
    return [0.5] * 6 + [0] * 4
|
|
|
def reset_sliders():
    """Return zeros for all ten sliders (6 style-element + 4 writing-type)."""
    return [0] * 10
|
|
|
def toggle_slider(checked, value):
    """Enable a slider at `value` when its checkbox is ticked; otherwise zero and disable it."""
    if not checked:
        return gr.update(value=0, interactive=False)
    return gr.update(value=value, interactive=True)
|
|
|
def reset_writing_type_sliders(selected_type):
    """Zero all four writing-type sliders, enabling only the one matching the radio choice.

    "None" (or any unknown value) leaves all four disabled.
    """
    writing_types = ["Persuasive", "Descriptive", "Narrative", "Expository"]
    return [
        gr.update(value=0, interactive=(writing_type == selected_type))
        for writing_type in writing_types
    ]
|
|
|
def update_save_feedback_button(feedback_rating, feedback_text):
    """Enable the submit button (and hide the warning) once any feedback exists."""
    has_feedback = feedback_rating != "No Feedback Selected" or feedback_text.strip() != ""
    return gr.update(interactive=has_feedback), gr.update(visible=not has_feedback)
|
|
|
def update_obfuscate_button(input_text):
    """Disable the obfuscate button (and show the warning) while the input is blank."""
    is_blank = input_text.strip() == ""
    return gr.update(interactive=not is_blank), gr.update(visible=is_blank)
|
|
|
def check_initial_feedback_state(feedback_rating, feedback_text):
    """Seed the feedback button/warning state on page load.

    Thin delegate to update_save_feedback_button(); exists so demo.load()
    has a distinctly named handler.
    """
    return update_save_feedback_button(feedback_rating, feedback_text)
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: two-panel layout — input + sliders on the left, output + feedback
# on the right.
# ---------------------------------------------------------------------------
demo = gr.Blocks()

with demo:
    # Per-session record of the most recent obfuscation (filled by greet()).
    latest_obfuscation = gr.State({})
    gr.Markdown(DESCRIPTION)
    gr.HTML(hide_css)
    with gr.Row():
        with gr.Column(variant="panel"):
            gr.Markdown("# 1) Input Text\n### Enter the text to be obfuscated. We recommend *full sentences* or *paragraphs*.")
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="The quick brown fox jumped over the lazy dogs."
            )
            gr.Markdown("# 2) Style Element Sliders\n### Adjust the style element sliders to the desired levels to steer the obfuscation.")

            with gr.Row():
                auto_button = gr.Button("Choose slider values automatically (based on input text)")
                reset_button = gr.Button("Reset slider values")

            sliders = []
            # (label, min, max, default) — order must match greet()'s signature:
            # 6 style-element sliders first, then the 4 writing-type sliders.
            slider_values = [
                ("Length (Shorter \u2192 Longer)", -1, 1, 0),
                ("Function Words (Fewer \u2192 More)", -1, 1, 0),
                ("Grade Level (Lower \u2192 Higher)", -1, 1, 0),
                ("Formality (Less \u2192 More)", -1, 1, 0),
                ("Sarcasm (Less \u2192 More)", -1, 1, 0),
                ("Voice (Passive \u2192 Active)", -1, 1, 0),
                ("Writing Type: Persuasive (None \u2192 More)", 0, 1, 0),
                ("Writing Type: Descriptive (None \u2192 More)", 0, 1, 0),
                ("Writing Type: Narrative (None \u2192 More)", 0, 1, 0),
                ("Writing Type: Expository (None \u2192 More)", 0, 1, 0)
            ]

            non_writing_type_sliders = []
            writing_type_sliders = []

            # Style-element sliders: each sits next to a checkbox that
            # enables/disables it via toggle_slider().
            for idx, (label, min_val, max_val, default) in enumerate(slider_values):
                if "Writing Type" not in label:
                    with gr.Row():
                        checkbox = gr.Checkbox(label=label.split("(")[0], scale=1)
                        slider = gr.Slider(label=label.split("(")[1][:-1], minimum=min_val, maximum=max_val, step=0.01, value=default, interactive=False, scale=3)
                    checkbox.change(fn=toggle_slider, inputs=[checkbox, gr.State(default)], outputs=slider)
                    non_writing_type_sliders.append(slider)
                    sliders.append(slider)

            # A single radio selects at most one writing type at a time.
            writing_type_radio = gr.Radio(
                label="Writing Type",
                choices=["None", "Persuasive", "Descriptive", "Narrative", "Expository"],
                value="None"
            )

            # NOTE(review): writing_type_sliders is still empty at this point;
            # the sliders are appended in the loop below. Verify the change
            # event actually resolves the outputs added afterwards.
            writing_type_radio.change(fn=reset_writing_type_sliders, inputs=writing_type_radio, outputs=writing_type_sliders)

            # Writing-type sliders (disabled until their radio option is chosen).
            for idx, (label, min_val, max_val, default) in enumerate(slider_values):
                if "Writing Type" in label:
                    with gr.Row():
                        slider = gr.Slider(label=label, minimum=min_val, maximum=max_val, step=0.01, value=default, interactive=False)
                    writing_type_sliders.append(slider)
                    sliders.append(slider)

            obfuscate_button = gr.Button("Obfuscate Text", interactive=False)
            warning_message = gr.Markdown(
                "<div style='text-align: center; color: red;'>⚠️ Please enter text before obfuscating. ⚠️</div>", visible=True
            )

            auto_button.click(fn=auto_sliders, inputs=[], outputs=sliders)
            reset_button.click(fn=reset_sliders, inputs=[], outputs=sliders)

            # Keep the obfuscate button/warning in sync with the input text,
            # both on edits and on initial page load.
            input_text.change(fn=update_obfuscate_button, inputs=input_text, outputs=[obfuscate_button, warning_message])

            demo.load(fn=update_obfuscate_button, inputs=input_text, outputs=[obfuscate_button, warning_message])

        with gr.Column(variant="panel"):
            gr.Markdown("# 3) Obfuscated Output")

            output = gr.Textbox(label="Output", lines=3)

            gr.Markdown("## Feedback [Optional]")

            gr.Markdown("### Is the response good or bad?")
            # Feedback widgets start disabled; greet() enables them once an
            # obfuscation has been produced.
            feedback_rating = gr.Radio(choices=["No Feedback Selected", "Good 👍", "Bad 👎"], value="No Feedback Selected", interactive=False, label="Rate the Response")

            gr.Markdown("### Provide any feedback on the obfuscation")
            feedback_text = gr.Textbox(label="Feedback", lines=3, interactive=False)

            obfuscate_button.click(
                fn=greet,
                inputs=[input_text] + sliders,
                outputs=[output, feedback_rating, feedback_text, latest_obfuscation])

            save_feedback_button = gr.Button("Submit Feedback", interactive=False)

            confirmation_message = gr.Markdown(
                "<div id='confirmation-message' style='text-align: center; color: green;'>🥳 Feedback has been submitted successfully! 🎊</div>", visible=False
            )

            feedback_warning_message = gr.Markdown(
                "<div id='feedback-warning' style='text-align: center; color: red;'>⚠️ Please provide feedback or a rating before submitting. ⚠️</div>", visible=True
            )

            # Enable submission as soon as any rating or text is provided.
            feedback_rating.change(fn=update_save_feedback_button, inputs=[feedback_rating, feedback_text], outputs=[save_feedback_button, feedback_warning_message])
            feedback_text.change(fn=update_save_feedback_button, inputs=[feedback_rating, feedback_text], outputs=[save_feedback_button, feedback_warning_message])

            save_feedback_button.click(
                fn=save_feedback,
                inputs=[feedback_rating, feedback_text, latest_obfuscation],
                outputs=[feedback_rating, feedback_text, confirmation_message]
            )

            # Second click handler runs client-side JS (from utils.hide_code)
            # alongside the Python handler above.
            save_feedback_button.click(
                fn=None,
                inputs=[],
                outputs=None,
                js=hide_code
            )

    # Initialize the feedback button/warning state on page load.
    demo.load(fn=check_initial_feedback_state, inputs=[feedback_rating, feedback_text], outputs=[save_feedback_button, feedback_warning_message])

demo.launch()
|
|