import spaces
import gradio as gr
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import subprocess
from datetime import datetime
from uuid import uuid4
import os
from pathlib import Path
from huggingface_hub import CommitScheduler
from utils import hide_code, hide_css
# TODO make it so that feedback is only saved on prev. example if user makes another obfuscation
#      and changes slider but doesn't hit obfuscate
# TODO maybe make it save and reset if user hits submit feedback
# TODO sampling params for models
# TODO obfuscation ID?
# Converts text to the prompt format expected by the LoRA adapters in StyleRemix
def convert_data_to_format(text):
    output = f"### Original: {text}\n ### Rewrite:"
    return output
MODEL_PATHS = {
    "length_more": "hallisky/lora-length-long-llama-3-8b",
    "length_less": "hallisky/lora-length-short-llama-3-8b",
    "function_more": "hallisky/lora-function-more-llama-3-8b",
    "function_less": "hallisky/lora-function-less-llama-3-8b",
    "grade_more": "hallisky/lora-grade-highschool-llama-3-8b",
    "grade_less": "hallisky/lora-grade-elementary-llama-3-8b",
    "formality_more": "hallisky/lora-formality-formal-llama-3-8b",
    "formality_less": "hallisky/lora-formality-informal-llama-3-8b",
    "sarcasm_more": "hallisky/lora-sarcasm-more-llama-3-8b",
    "sarcasm_less": "hallisky/lora-sarcasm-less-llama-3-8b",
    "voice_passive": "hallisky/lora-voice-passive-llama-3-8b",
    "voice_active": "hallisky/lora-voice-active-llama-3-8b",
    "type_persuasive": "hallisky/lora-type-persuasive-llama-3-8b",
    "type_expository": "hallisky/lora-type-expository-llama-3-8b",
    "type_narrative": "hallisky/lora-type-narrative-llama-3-8b",
    "type_descriptive": "hallisky/lora-type-descriptive-llama-3-8b",
}
FIRST_MODEL = list(MODEL_PATHS.keys())[0]
MAX_NEW_TOKENS = 1024
DESCRIPTION = """\
# Authorship Obfuscation with StyleRemix
This Space demonstrates StyleRemix, a controllable and interpretable method for authorship obfuscation. At its core, it uses an 8B-parameter Llama-3 model with a set of LoRA adapters, each fine-tuned to rewrite text along a specific stylistic attribute (for example, making the text longer or shorter). Feel free to play with it, or duplicate the Space to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
<br> 🕵️ Want to learn more? Check out our paper [here](google.com) and our code [here](google.com)!
<br> 📧 Have questions about our work or issues with the demo? Feel free to email us at hallisky@uw.edu.
"""
def print_nvidia_smi():
    try:
        # Run the nvidia-smi command and print GPU usage
        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, check=True)
        print(result.stdout)
    except subprocess.CalledProcessError as e:
        # Handle errors in the subprocess
        print(f"Failed to run nvidia-smi: {e}")
    except FileNotFoundError:
        # Handle the case where nvidia-smi is not installed
        print("nvidia-smi is not installed or not in the PATH.")
# Load models
if not torch.cuda.is_available():
    device = "cpu"
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

if torch.cuda.is_available():
    device = "cuda"
    model_id = "meta-llama/Meta-Llama-3-8B"
    tokenizer = AutoTokenizer.from_pretrained(model_id, add_bos_token=True, add_eos_token=False, padding_side="left")
    tokenizer.add_special_tokens({'pad_token': '<padding_token>'})
    base_model = AutoModelForCausalLM.from_pretrained(model_id).to(device)  # device_map="auto" requires accelerate
    base_model.resize_token_embeddings(len(tokenizer))  # Resize embeddings to add the pad token; its initial value doesn't matter

    # Load in the first adapter
    model = PeftModel.from_pretrained(base_model, MODEL_PATHS[FIRST_MODEL], adapter_name=FIRST_MODEL).to(device)

    # Load in the rest of the adapters
    for cur_adapter in MODEL_PATHS.keys():
        if cur_adapter != FIRST_MODEL:
            model.load_adapter(MODEL_PATHS[cur_adapter], adapter_name=cur_adapter)
    # print(model.device)  # Loading adapters seems to re-allocate the model to CPU
    model.to(device)
    model.eval()
# Session identifier and local file where obfuscation results and feedback are appended
user_id = str(uuid4())  # Generate a unique session-specific user ID
JSON_DATASET_DIR = Path("json_dataset")
JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)
JSON_DATASET_PATH = JSON_DATASET_DIR / f"train-{user_id}.json"
scheduler = CommitScheduler(
    repo_id="authorship-obfuscation-demo-data",
    repo_type="dataset",
    folder_path=JSON_DATASET_DIR,
    path_in_repo="data",
    every=0.5
)
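# CommitScheduler pushes the contents of JSON_DATASET_DIR to the dataset repo in the background;
# `every` is measured in minutes, so new records are committed roughly every 30 seconds.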
def save_data(data):
    # Append one JSON record per line; the lock avoids writing while a scheduled commit is in progress
    with scheduler.lock:
        with JSON_DATASET_PATH.open("a") as f:
            json.dump(data, f)
            f.write("\n")
def save_feedback(feedback_rating, feedback_text, latest_obfuscation):
    # Attach the feedback to the most recent obfuscation record, save it, and reset the feedback widgets
    latest_obfuscation["feedback_rating"] = feedback_rating
    latest_obfuscation["feedback_text"] = feedback_text
    save_data(latest_obfuscation)
    return "No Feedback Selected", "", gr.update(visible=True)
def greet(input_text, length, function_words, grade_level, formality, sarcasm, voice, persuasive, descriptive, narrative, expository):
    global latest_obfuscation, user_id
    current_time = datetime.now().isoformat()

    # Map each slider to an adapter (positive -> "more" direction, negative -> "less"),
    # using the absolute slider value as that adapter's weight; sliders at 0 are skipped
    sliders_dict = {}
    cur_keys = []
    cur_keys.append(("length_more" if length > 0 else (None if length == 0 else "length_less"), abs(length)))
    cur_keys.append(("function_more" if function_words > 0 else (None if function_words == 0 else "function_less"), abs(function_words)))
    cur_keys.append(("grade_more" if grade_level > 0 else (None if grade_level == 0 else "grade_less"), abs(grade_level)))
    cur_keys.append(("sarcasm_more" if sarcasm > 0 else (None if sarcasm == 0 else "sarcasm_less"), abs(sarcasm)))
    cur_keys.append(("formality_more" if formality > 0 else (None if formality == 0 else "formality_less"), abs(formality)))
    cur_keys.append(("voice_active" if voice > 0 else (None if voice == 0 else "voice_passive"), abs(voice)))
    cur_keys.append(("type_persuasive" if persuasive != 0 else None, abs(persuasive)))
    cur_keys.append(("type_descriptive" if descriptive != 0 else None, abs(descriptive)))
    cur_keys.append(("type_narrative" if narrative != 0 else None, abs(narrative)))
    cur_keys.append(("type_expository" if expository != 0 else None, abs(expository)))
    for cur_key in cur_keys:
        if cur_key[0] is not None:
            sliders_dict[cur_key[0]] = cur_key[1]
    # Build a combined adapter from the active sliders and switch to it
    print(sliders_dict)
    if len(sliders_dict) > 0:
        # Name the combined adapter after its component adapters and weights
        combo_adapter_name = ""
        for slider_key in sliders_dict:
            print(slider_key)
            print(sliders_dict[slider_key])
            combo_adapter_name += slider_key + str(int(100 * sliders_dict[slider_key])) + "-"
        combo_adapter_name = combo_adapter_name[:-1]
        print(combo_adapter_name)
        print(list(sliders_dict.values()))
        print(list(sliders_dict.keys()))
        print(list(model.peft_config.keys()))

        # Add and set the weighted adapter
        model.add_weighted_adapter(
            list(sliders_dict.keys()),
            weights=list(sliders_dict.values()),
            adapter_name=combo_adapter_name,
            combination_type="cat"
        )
        model.set_adapter(combo_adapter_name)
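        # For example, sliders_dict = {"length_more": 0.5, "sarcasm_more": 1.0} yields the call
        #   model.add_weighted_adapter(["length_more", "sarcasm_more"], weights=[0.5, 1.0],
        #                              adapter_name="length_more50-sarcasm_more100", combination_type="cat")
        # where "cat" concatenates the selected LoRA adapters, each scaled by its slider-derived weight.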
        # Convert the input string into the prompt format expected by the adapters and generate
        converted_text = convert_data_to_format(input_text)
        inputs = tokenizer(converted_text, return_tensors="pt", max_length=2048, truncation=True).to(device)
        input_length = inputs.input_ids.shape[1]
        with torch.no_grad():
            # Note: top_p only takes effect when sampling is enabled (do_sample=True); see the sampling-params TODO above
            outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, top_p=0.95)
        response = tokenizer.decode(outputs[0, input_length:], skip_special_tokens=True).strip()
        full_output = tokenizer.decode(outputs[0], skip_special_tokens=False)
    else:
        response = input_text  # If no sliders are set, return the input unchanged
        full_output = response
    # print_nvidia_smi()  # Print GPU usage
    # Save the new obfuscation result and reset feedback
    latest_obfuscation = {
        "datetime": current_time,
        "user_id": user_id,
        "input_text": input_text,
        "sliders": {
            "length": length,
            "function_words": function_words,
            "grade_level": grade_level,
            "sarcasm": sarcasm,
            "formality": formality,
            "voice": voice,
            "persuasive": persuasive,
            "descriptive": descriptive,
            "narrative": narrative,
            "expository": expository
        },
        "input": input_text,
        "output": response,
        "full_output": full_output,
        "feedback_rating": "No Feedback Selected",
        "feedback_text": ""
    }

    # Save the obfuscation result
    save_data(latest_obfuscation)
    return response, gr.update(interactive=True), gr.update(interactive=True), latest_obfuscation
def auto_sliders():
    return [0.5] * 7 + [0] * 3

def reset_sliders():
    return [0] * 7 + [0] * 3

def toggle_slider(checked, value):
    # Unlock the slider at its default value when its checkbox is ticked; otherwise zero and lock it
    if checked:
        return gr.update(value=value, interactive=True)
    else:
        return gr.update(value=0, interactive=False)

def reset_writing_type_sliders(selected_type):
    # Zero and lock all four writing-type sliders, then unlock only the one matching the selected type
    reset_values = [gr.update(value=0, interactive=False) for _ in range(4)]
    if selected_type != "None":
        index = ["Persuasive", "Descriptive", "Narrative", "Expository"].index(selected_type)
        reset_values[index] = gr.update(value=0, interactive=True)
    return reset_values
def update_save_feedback_button(feedback_rating, feedback_text):
    # Enable the submit button only once a rating or some feedback text has been provided
    if feedback_rating != "No Feedback Selected" or feedback_text.strip() != "":
        return gr.update(interactive=True), gr.update(visible=False)
    else:
        return gr.update(interactive=False), gr.update(visible=True)

def update_obfuscate_button(input_text):
    # Disable the obfuscate button (and show a warning) while the input box is empty
    if input_text.strip() == "":
        return gr.update(interactive=False), gr.update(visible=True)
    else:
        return gr.update(interactive=True), gr.update(visible=False)

def check_initial_feedback_state(feedback_rating, feedback_text):
    return update_save_feedback_button(feedback_rating, feedback_text)
demo = gr.Blocks()

with demo:
    latest_obfuscation = gr.State({})  # Per-session store for the most recent obfuscation record
    gr.Markdown(DESCRIPTION)
    gr.HTML(hide_css)

    with gr.Row():
        with gr.Column(variant="panel"):
            gr.Markdown("# 1) Input Text\n### Enter the text to be obfuscated. We recommend *full sentences* or *paragraphs*.")
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="The quick brown fox jumped over the lazy dogs."
            )

            gr.Markdown("# 2) Style Element Sliders\n### Adjust the style element sliders to the desired levels to steer the obfuscation.")
            with gr.Row():
                auto_button = gr.Button("Choose slider values automatically (based on input text)")
                reset_button = gr.Button("Reset slider values")
            sliders = []
            slider_values = [
                ("Length (Shorter \u2192 Longer)", -1, 1, 0),
                ("Function Words (Fewer \u2192 More)", -1, 1, 0),
                ("Grade Level (Lower \u2192 Higher)", -1, 1, 0),
                ("Formality (Less \u2192 More)", -1, 1, 0),
                ("Sarcasm (Less \u2192 More)", -1, 1, 0),
                ("Voice (Passive \u2192 Active)", -1, 1, 0),
                ("Writing Type: Persuasive (None \u2192 More)", 0, 1, 0),
                ("Writing Type: Descriptive (None \u2192 More)", 0, 1, 0),
                ("Writing Type: Narrative (None \u2192 More)", 0, 1, 0),
                ("Writing Type: Expository (None \u2192 More)", 0, 1, 0)
            ]
            non_writing_type_sliders = []
            writing_type_sliders = []

            # One checkbox + slider row per style element (writing types are handled separately below)
            for idx, (label, min_val, max_val, default) in enumerate(slider_values):
                if "Writing Type" not in label:
                    with gr.Row():
                        checkbox = gr.Checkbox(label=label.split("(")[0], scale=1)
                        slider = gr.Slider(label=label.split("(")[1][:-1], minimum=min_val, maximum=max_val, step=0.01, value=default, interactive=False, scale=3)
                        checkbox.change(fn=toggle_slider, inputs=[checkbox, gr.State(default)], outputs=slider)
                    non_writing_type_sliders.append(slider)
                    sliders.append(slider)
            writing_type_radio = gr.Radio(
                label="Writing Type",
                choices=["None", "Persuasive", "Descriptive", "Narrative", "Expository"],
                value="None"
            )

            # Create the four writing-type sliders (locked until a type is chosen in the radio above)
            for idx, (label, min_val, max_val, default) in enumerate(slider_values):
                if "Writing Type" in label:
                    with gr.Row():
                        slider = gr.Slider(label=label, minimum=min_val, maximum=max_val, step=0.01, value=default, interactive=False)
                    writing_type_sliders.append(slider)
                    sliders.append(slider)

            # Register the radio handler after the sliders exist so its outputs list is fully populated
            writing_type_radio.change(fn=reset_writing_type_sliders, inputs=writing_type_radio, outputs=writing_type_sliders)
            obfuscate_button = gr.Button("Obfuscate Text", interactive=False)
            warning_message = gr.Markdown(
                "<div style='text-align: center; color: red;'>⚠️ Please enter text before obfuscating. ⚠️</div>", visible=True
            )

            auto_button.click(fn=auto_sliders, inputs=[], outputs=sliders)
            reset_button.click(fn=reset_sliders, inputs=[], outputs=sliders)
            input_text.change(fn=update_obfuscate_button, inputs=input_text, outputs=[obfuscate_button, warning_message])

            # Initialize the button and warning message state on page load
            demo.load(fn=update_obfuscate_button, inputs=input_text, outputs=[obfuscate_button, warning_message])
        with gr.Column(variant="panel"):
            gr.Markdown("# 3) Obfuscated Output")
            output = gr.Textbox(label="Output", lines=3)

            gr.Markdown("## Feedback [Optional]")
            # Thumbs up / thumbs down rating
            gr.Markdown("### Is the response good or bad?")
            feedback_rating = gr.Radio(choices=["No Feedback Selected", "Good 👍", "Bad 👎"], value="No Feedback Selected", interactive=False, label="Rate the Response")
            # Free-form feedback box
            gr.Markdown("### Provide any feedback on the obfuscation")
            feedback_text = gr.Textbox(label="Feedback", lines=3, interactive=False)

            obfuscate_button.click(
                fn=greet,
                inputs=[input_text] + sliders,
                outputs=[output, feedback_rating, feedback_text, latest_obfuscation])

            save_feedback_button = gr.Button("Submit Feedback", interactive=False)
            confirmation_message = gr.Markdown(
                "<div id='confirmation-message' style='text-align: center; color: green;'>🥳 Feedback has been submitted successfully! 🎉</div>", visible=False
            )
            feedback_warning_message = gr.Markdown(
                "<div id='feedback-warning' style='text-align: center; color: red;'>⚠️ Please provide feedback or a rating before submitting. ⚠️</div>", visible=True
            )
            # Update the interactivity of the save_feedback_button based on feedback_rating and feedback_text
            feedback_rating.change(fn=update_save_feedback_button, inputs=[feedback_rating, feedback_text], outputs=[save_feedback_button, feedback_warning_message])
            feedback_text.change(fn=update_save_feedback_button, inputs=[feedback_rating, feedback_text], outputs=[save_feedback_button, feedback_warning_message])

            save_feedback_button.click(
                fn=save_feedback,
                inputs=[feedback_rating, feedback_text, latest_obfuscation],
                outputs=[feedback_rating, feedback_text, confirmation_message]
            )
            save_feedback_button.click(
                fn=None,
                inputs=[],
                outputs=None,
                js=hide_code
            )

    # Initialize the save feedback button and warning message state on page load
    demo.load(fn=check_initial_feedback_state, inputs=[feedback_rating, feedback_text], outputs=[save_feedback_button, feedback_warning_message])

demo.launch()