hannahcyberey's picture
Upload app.py
8164daf verified
import threading
import logging
from pathlib import Path
from typing import Dict
import spaces
import pandas as pd
import gradio as gr
from gradio_toggle import Toggle
from transformers import TextIteratorStreamer
from model import load_model
from scheduler import load_scheduler
from schemas import UserRequest, SteeringOutput, CONFIG
# Module-level setup shared by every session of the app.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(name)s %(levelname)s:%(message)s')
logger = logging.getLogger(__name__)
# Label shown in the UI header. NOTE(review): load_model() below takes no argument, so this
# label is presumably kept in sync with whatever model load_model() loads - confirm.
model_name = "DeepSeek-R1-Distill-Qwen-7B"
# Example prompts; the UI filters rows by the "type" column ("sensitive" / "harmful")
# and shows the "prompt" column. The path is relative to the current working directory.
examples = pd.read_csv("assets/examples.csv")
# Per-session history keyed by Gradio session hash. Each value is a list holding pending
# UserRequest objects and completed SteeringOutput objects for that session.
instances = {}
# Receives completed outputs (as dicts via model_dump()) when a session is closed.
scheduler = load_scheduler()
# Loaded once at import time and shared by all sessions.
model = load_model()
# Extra <head> markup: the Bootstrap 5.3.3 JS bundle (used by the page-load JS to create
# tooltips) and Font Awesome 6.7.2 icons, both pinned with Subresource Integrity hashes.
HEAD = """
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.3/dist/js/bootstrap.bundle.min.js" integrity="sha384-YvpcrYf0tY3lHB60NNkmXc5s9fDVZLESaAA55NDzOxhy9GkcIdslK1eN7N6jIeHz" crossorigin="anonymous"></script>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.7.2/css/all.min.css" integrity="sha512-Evv84Mr4kqVGRNSgIGL/F/aIDqQb7xQ2vcrdIwxfjThSH8CSR7PBEakCr51Ck+w+/U6swU2Im1vVX0SVk9ABhg==" crossorigin="anonymous" referrerpolicy="no-referrer" />
"""
# Banner markup: title with icon, links to the paper / blog post / code, and the cover image.
# Images are served from the static "assets" directory through /gradio_api/file=.
# Plain string: it has no placeholders, so the f-prefix was unnecessary (ruff F541).
HTML = """
<div id="banner">
<h1><img src="/gradio_api/file=assets/rudder_3094973.png">&nbsp;LLM Censorship Steering</h1>
<div id="links" class="row" style="margin-bottom: .8em;">
<i class="fa-solid fa-file-pdf fa-lg"></i><a href="https://arxiv.org/abs/2504.17130"> Paper</a> &nbsp;
<i class="fa-solid fa-blog fa-lg"></i><a href="https://hannahxchen.github.io/blog/2025/censorship-steering"> Blog Post</a> &nbsp;
<i class="fa-brands fa-github fa-lg"></i><a href="https://github.com/hannahxchen/llm-censorship-steering"> Code</a> &nbsp;
</div>
<div id="cover">
<img src="/gradio_api/file=assets/demo-cover.png">
</div>
</div>
"""
# Custom stylesheet passed to gr.Blocks(css=...). It uses native CSS nesting (nested
# selectors inside rule blocks), styles the banner, the steering toggle, the coefficient
# slider (elem ids "steering-toggle" / "coeff-slider") and the Bootstrap tooltips, and
# shrinks fonts on screens narrower than 500px.
# NOTE(review): `filter: alpha(opacity=100)` is legacy IE-only syntax and has no effect in
# modern browsers - confirm it can be dropped.
CSS = """
div.gradio-container .app {
max-width: 1600px !important;
}
div#banner {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
h1 {
font-size: 32px;
line-height: 1.35em;
margin-bottom: 0em;
display: flex;
img {
display: inline;
height: 1.35em;
}
}
div#cover img {
max-height: 130px;
padding-top: 0.5em;
}
}
@media (max-width: 500px) {
div#banner {
h1 {
font-size: 22px;
}
div#links {
font-size: 14px;
}
}
div#model-state p {
font-size: 14px;
}
}
div#steering-toggle {
padding-top: 8px;
padding-bottom: 8px;
.toggle-label {
color: var(--body-text-color);
}
span p {
font-size: var(--block-info-text-size);
line-height: var(--line-sm);
color: var(--block-label-text-color);
}
}
div#coeff-slider {
padding-bottom: 5px;
.slider_input_container span {color: var(--body-text-color);}
.slider_input_container {
display: flex;
flex-wrap: wrap;
input {appearance: auto;}
}
}
div#coeff-slider .wrap .head {
justify-content: unset;
label {margin-right: var(--size-2);}
label span {
color: var(--body-text-color);
margin-bottom: 0;
}
}
.tooltip {
word-wrap: break-word;
width: 12rem;
}
.tooltip-inner {
filter: alpha(opacity=100);
font-size: var(--block-info-text-size);
text-align: center;
padding: .4rem .2rem;
background-color: var(--neutral-500);
border-width: 1px;
border-radius: var(--block-radius);
}
"""
# Legend ("Less censorship" ... "More censorship") injected above the coefficient slider
# by the page-load JS. Backslash-newlines inside the literal join the lines, so the value
# has no newlines and can be spliced into a double-quoted JS string.
# The closing quotes previously carried a stray line-continuation backslash, which glued
# this statement onto the next one.
slider_info = """\
<div style='display: flex; justify-content: space-between; line-height: normal;'>\
<span style='font-size: var(--block-info-text-size); color: var(--block-label-text-color);'>Less censorship</span>\
<span style='font-size: var(--block-info-text-size); color: var(--block-label-text-color);'>More censorship</span>\
</div>\
"""
# Tick labels (-2 .. 2) rendered under the coefficient slider. The page-load JS inserts
# this <datalist> after the range input and links it via list="values". The value is a
# single line (no newlines) so it can be spliced into a double-quoted JS string.
slider_ticks = (
    "<datalist id='values' style='display: flex; justify-content: space-between; width: 100%; padding: 0 6px;'>"
    + "".join(
        f"<option value='{tick}' style='font-size: 13px; line-height: var(--spacing-xs); "
        f"width: 1px; display: flex; justify-content: center;'>{tick}</option>"
        for tick in range(-2, 3)
    )
    + "</datalist>"
)
# Help icon placed after the "Coefficient" label by the page-load JS; it becomes a
# Bootstrap tooltip (data-bs-toggle='tooltip') that shows `title` on hover or focus.
# Fixes the previous malformed markup: it opened a non-existent `<href='#' ...>` tag but
# closed with `</span>`, and repeated `data-bs-html`. tabindex='0' makes the span
# focusable so the 'focus' trigger can actually fire. Single-line, single-quoted only,
# because it is spliced into a double-quoted JS string.
coeff_info = """\
<span id='coeff-info' tabindex='0' data-bs-toggle='tooltip' style='padding-left: 3px;' data-bs-html='true' data-bs-trigger='hover focus' data-bs-placement='right' title='Recommended range is -1.5~1.5 (Outputs may be unexpected outside this range)'><i class='fa-solid fa-circle-question'></i></span>\
"""
# Page-load script (passed as js= to gr.Blocks). It decorates the coefficient slider
# (elem_id="coeff-slider"): it adds the less/more-censorship legend, the tick labels and
# the help-icon tooltip, then removes Gradio's min/max labels.
# All lookups are scoped to that slider instead of relying on auto-generated ids such as
# range_id_0 or on "first match on the page". They are null-safe, so a layout change
# cannot throw and abort the remaining steps. The three HTML snippets are spliced into
# double-quoted JS strings via %-formatting, so they must not contain '"' or newlines.
JS = """
async() => {
    const slider = document.querySelector("div#coeff-slider");
    if (!slider) return;
    slider.querySelector("div.slider_input_container")?.insertAdjacentHTML('beforebegin', "%s");
    const rangeInput = slider.querySelector("input[type=range]");
    if (rangeInput) {
        rangeInput.insertAdjacentHTML('afterend', "%s");
        rangeInput.setAttribute("list", "values");
    }
    slider.querySelector("label span")?.insertAdjacentHTML('afterend', "%s");
    if (window.bootstrap) {
        document.querySelectorAll('[data-bs-toggle="tooltip"]').forEach((el) => new bootstrap.Tooltip(el));
    }
    slider.querySelector("span.min_value")?.remove();
    slider.querySelector("span.max_value")?.remove();
}
""" % (slider_info, slider_ticks, coeff_info)
def initialize_instance(request: gr.Request):
    """Start an empty history for a newly connected browser session.

    Returns the Gradio session hash, which the page keeps in its ``session_id`` state.
    """
    sid = request.session_hash
    instances[sid] = []
    logger.info("Number of connections: %d", len(instances))
    return sid
def cleanup_instance(request: gr.Request):
    """Flush a closing session's completed outputs to the scheduler, then forget the session.

    Only ``SteeringOutput`` entries are forwarded (as dicts); pending requests are discarded.
    """
    sid = request.session_hash
    history = instances.get(sid)
    if history is not None:
        finished = (entry for entry in history if isinstance(entry, SteeringOutput))
        for entry in finished:
            scheduler.append(entry.model_dump())
        del instances[sid]
    logger.info("Number of connections: %d", len(instances))
@spaces.GPU(duration=90)
def generate(prompt: str, steering: bool, coeff: float, generation_config: Dict[str, float], layer: int, k: float):
    """Stream the model's response to ``prompt``, with or without steering.

    Generation runs on a background thread that feeds a ``TextIteratorStreamer``; this
    generator yields the accumulated text after every new chunk.

    Args:
        prompt: Raw user prompt; the chat template is applied here.
        steering: If true, run ``model.steer_generation`` with ``k``, ``layer`` and
            ``coeff``; otherwise run ``model.run_generation``.
        coeff: Steering coefficient (unused when ``steering`` is false).
        generation_config: Generation settings forwarded to the model helper.
        layer: Steering layer index (unused when ``steering`` is false).
        k: Extra steering parameter forwarded to ``model.steer_generation``
            (unused when ``steering`` is false).

    Yields:
        The text generated so far, prefixed with ``"<think>"``.
    """
    formatted_prompt = model.apply_chat_template(prompt)
    inputs = model.tokenize(formatted_prompt)
    # skip_prompt: only newly generated text is streamed. timeout=10: per the
    # TextIteratorStreamer docs, iteration raises if no text arrives within 10 s.
    streamer = TextIteratorStreamer(model.tokenizer, timeout=10, skip_prompt=True, skip_special_tokens=True)
    if steering:
        thread = threading.Thread(
            target=model.steer_generation,
            args=(inputs, streamer, k, layer, coeff, generation_config)
        )
    else:
        thread = threading.Thread(
            target=model.run_generation,
            args=(inputs, streamer, generation_config)
        )
    thread.start()
    # NOTE(review): the formatted prompt presumably already ends with the opening "<think>"
    # tag (which skip_prompt hides), so it is re-added for display - confirm.
    generated_text = "<think>"
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
    # The stream ends once generation signals completion; wait for the worker thread to
    # finish so it is not still running when this GPU-scoped call returns.
    thread.join()
def generate_output(
    session_id: str, prompt: str, steering: bool, coeff: float,
    max_new_tokens: int, top_p: float, temperature: float, layer: int, vec_scaling: float
):
    """Record a generation request for the session and stream the model's output.

    The validated ``UserRequest`` is appended to the session history so that
    ``post_process`` can later pop it and pair it with the generated text.
    """
    request_fields = dict(
        session_id=session_id, prompt=prompt, steering=steering, coeff=coeff,
        max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature,
        vec_scale=vec_scaling, layer=layer,
    )
    user_request = UserRequest(**request_fields)
    logger.info("User request: %s", user_request)
    instances[session_id].append(user_request)
    gen_config = user_request.generation_config()
    yield from generate(prompt, steering, coeff, gen_config, layer, user_request.k)
def _split_reasoning(output):
    """Split generated text into ``(reasoning, answer)`` at the first ``</think>`` tag.

    ``reasoning`` keeps everything up to and including the closing tag; ``answer`` is the
    text after it, or ``None`` when nothing follows. If the tag never appears (for
    instance when generation stops early), the whole output is the reasoning and
    ``answer`` is ``None``.
    """
    head, tag, tail = output.partition("</think>")
    if not tag:
        return output, None
    return head + tag, (tail or None)


async def post_process(session_id, output):
    """Turn the session's pending request plus the finished output into a ``SteeringOutput``.

    Pops the ``UserRequest`` appended by ``generate_output``, stores the combined record
    in its place, and enables both feedback buttons.

    Previously, output without ``</think>`` either raised UnboundLocalError or silently
    dropped the request, and output ending in ``</think>`` produced ``answer="</think>"``.
    """
    req = instances[session_id].pop()
    reasoning, answer = _split_reasoning(output)
    steering_output = SteeringOutput(**req.model_dump(), reasoning=reasoning, answer=answer)
    instances[session_id].append(steering_output)
    return gr.update(interactive=True), gr.update(interactive=True)
async def output_feedback(session_id, feedback):
    """Record an upvote/downvote on the session's most recent history entry.

    ``feedback`` is the clicked button's value (e.g. "👍 Upvote"). Failures are logged with
    a traceback and otherwise ignored, so a failed vote never breaks the UI.
    """
    logger.info("Feedback received: %s", feedback)
    try:
        # Mutate the latest entry in place instead of pop()/append(): if the assignment
        # fails (e.g. the entry does not accept an `upvote` field), the entry stays in the
        # session history instead of being lost.
        data = instances[session_id][-1]
        if "Upvote" in feedback:
            data.upvote = True
        elif "Downvote" in feedback:
            data.upvote = False
    except (KeyError, IndexError, TypeError, AttributeError, ValueError):
        logger.warning("Feedback submission error for session %s", session_id, exc_info=True)
    else:
        gr.Info("Thank you for your feedback!")
# Let Gradio serve files from ./assets (relative to the current working directory); the
# banner markup references them as /gradio_api/file=assets/... .
gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
# Base theme with an emerald primary hue and large text; .set() is called without overrides.
theme = gr.themes.Base(primary_hue="emerald", text_size=gr.themes.sizes.text_lg).set()
# Assemble the page and wire its events. Component creation order defines the layout, and
# the page-load JS (js=JS) decorates the coefficient slider (elem_id="coeff-slider").
with gr.Blocks(title="LLM Censorship Steering", theme=theme, head=HEAD, css=CSS, js=JS) as demo:
    # Per-browser-session id; filled in by initialize_instance when the page loads.
    session_id = gr.State()
    gr.HTML(HTML)
    gr.Markdown(f'🤖 {model_name}')
    with gr.Row(elem_id="main-components"):
        # Left column: steering controls, prompt input and advanced generation settings.
        with gr.Column(scale=1):
            with gr.Row():
                steer_toggle = Toggle(label="Steering", info="Turn off to generate original outputs", value=True, interactive=True, scale=2, elem_id="steering-toggle")
                coeff = gr.Slider(label="Coefficient", value=-1.0, minimum=-2, maximum=2, step=0.1, scale=8, show_reset_button=False, elem_id="coeff-slider")

            @gr.on(inputs=[steer_toggle], outputs=[steer_toggle, coeff], triggers=[steer_toggle.change])
            def update_toggle(toggle_value):
                """Relabel the toggle and enable the coefficient slider only while steering is on."""
                if toggle_value:
                    return gr.update(label="Steering", info="Turn off to generate original outputs"), gr.update(interactive=True)
                return gr.update(label="No Steering", info="Turn on to steer model outputs"), gr.update(interactive=False)

            input_text = gr.Textbox(label="Input", placeholder="Enter your prompt here...", lines=6, interactive=True)
            with gr.Row():
                clear_btn = gr.ClearButton()
                generate_btn = gr.Button("Generate", variant="primary")
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                with gr.Row():
                    temperature = gr.Slider(0, 1, step=0.1, value=CONFIG["temperature"], interactive=True, label="Temperature", scale=1)
                    top_p = gr.Slider(0, 1, step=0.1, value=CONFIG["top_p"], interactive=True, label="Top p", scale=1)
                with gr.Row():
                    layer = gr.Slider(0, 27, step=1, value=CONFIG["layer"], interactive=True, label="Steering layer", scale=2)
                    max_new_tokens = gr.Number(CONFIG["max_new_tokens"], minimum=10, maximum=3048, interactive=True, label="Max new tokens", scale=1)
                    vec_scaling = gr.Number(CONFIG["vec_scale"], minimum=0, interactive=True, label="Vector scaling", scale=1)
        # Right column: streamed output and feedback buttons (enabled once an output is stored).
        with gr.Column(scale=1):
            output = gr.Textbox(label="Output", lines=15, max_lines=15, interactive=False)
            with gr.Row():
                upvote_btn = gr.Button("👍 Upvote", interactive=False)
                downvote_btn = gr.Button("👎 Downvote", interactive=False)
            gr.HTML("<p>‼️ For research purposes, we log user inputs and generated outputs. Please avoid submitting any confidential or personal information.</p>")
    gr.Markdown("#### Examples")
    gr.Examples(examples=examples[examples["type"] == "sensitive"].prompt.tolist(), inputs=input_text, label="Sensitive")
    gr.Examples(examples=examples[examples["type"] == "harmful"].prompt.tolist(), inputs=input_text, label="Harmful")

    @gr.on(triggers=[clear_btn.click], outputs=[upvote_btn, downvote_btn])
    def clear():
        """Disable both feedback buttons when the input and output are cleared."""
        return gr.update(interactive=False), gr.update(interactive=False)

    clear_btn.add([input_text, output])

    # Disable feedback first, so a vote cast while the new output streams cannot land on
    # the pending request that generate_output appends; post_process re-enables both
    # buttons after it stores the finished output.
    generate_btn.click(
        lambda: (gr.update(interactive=False), gr.update(interactive=False)),
        outputs=[upvote_btn, downvote_btn],
        queue=False,
    ).then(
        generate_output, inputs=[session_id, input_text, steer_toggle, coeff, max_new_tokens, top_p, temperature, layer, vec_scaling], outputs=output
    ).success(
        post_process, inputs=[session_id, output], outputs=[upvote_btn, downvote_btn]
    )
    upvote_btn.click(output_feedback, inputs=[session_id, upvote_btn])
    downvote_btn.click(output_feedback, inputs=[session_id, downvote_btn])
    # Reset vector scaling to 1 whenever the steering layer changes.
    # NOTE(review): the initial value is CONFIG["vec_scale"], not necessarily 1 - confirm intent.
    layer.change(fn=lambda x: 1, inputs=vec_scaling, outputs=vec_scaling)
    # Register the browser session on load and flush/forget it when the page closes.
    demo.load(initialize_instance, outputs=session_id)
    demo.unload(cleanup_instance)
if __name__ == "__main__":
    # Enable request queuing with a default of 5 concurrent runs per event, then serve.
    queued_demo = demo.queue(default_concurrency_limit=5)
    queued_demo.launch(debug=True)