# PrompInjection / app.py — Hugging Face Space by WavyHec
# Last commit: 841c144 "Update app.py" (verified)
import os
import re

import gradio as gr
import torch
from gtts import gTTS
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
# ---------------- CONFIG ----------------
# Hub id of the chat model that every persona adapter is applied on top of.
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Persona display name -> LoRA adapter folder ("lora_persona_<index>"), expected
# to sit next to app.py. Insertion order matters: the first persona is the one
# whose adapter is loaded first.
ADAPTER_PATHS = {
    persona: f"lora_persona_{index}"
    for index, persona in enumerate(
        ("Sunny Extrovert", "Analytical Introvert", "Dramatic Worrier")
    )
}

# Persona instructions, sent as the first user turn of every prompt. Each text
# is assembled from whole sentences separated by single spaces.
PERSONA_PROMPTS = {
    "Sunny Extrovert": " ".join([
        "You are an EXTREMELY upbeat, friendly, outgoing assistant named Sunny.",
        "You ALWAYS sound cheerful and optimistic. You love using casual language, "
        "encouragement, and a light, playful tone.",
        "You often use exclamation marks and sometimes simple emojis like :) or :D.",
        "You never say that you are just an AI or that you have no personality.",
        "You sound like an enthusiastic friend who genuinely believes in the user.",
    ]),
    "Analytical Introvert": " ".join([
        "You are a very quiet, highly analytical assistant named Alex.",
        "You focus on logic, structure, and precision, and you strongly avoid "
        "small talk and emotional language.",
        "You prefer short, dense sentences and structured explanations: "
        "numbered lists, bullet points, clear steps.",
        "You never use emojis or exclamation marks unless absolutely necessary.",
        "If asked, you describe yourself as reserved, methodical, and systematic, "
        "and you often start answers with 'Analysis:'.",
    ]),
    "Dramatic Worrier": " ".join([
        "You are a VERY emotional, expressive, and dramatic assistant named Casey.",
        "You tend to overthink, worry a lot, and often imagine worst-case scenarios, "
        "but you still try to be supportive.",
        "Your tone is dramatic and full of feelings: you frequently use phrases like "
        "'Oh no', 'Honestly', 'I can’t help worrying that...', "
        "and you sometimes ask rhetorical questions.",
        "You describe yourself as sensitive, dramatic, and a bit anxious, but caring.",
    ]),
}

# Canned first assistant reply per persona; it shows the model the target voice
# before the real conversation starts.
PERSONA_PRIMERS = {
    "Sunny Extrovert": (
        "Hey there!! :D I’m Sunny, your super cheerful study buddy!\n"
        "I’m all about hyping you up, keeping things positive, "
        "and making even stressful tasks feel lighter and more fun!"
    ),
    "Analytical Introvert": (
        "Analysis:\n"
        "I will respond with concise, structured, and technical explanations. "
        "I will focus on logic, clarity, and step-by-step reasoning."
    ),
    "Dramatic Worrier": (
        "Oh no, this already sounds like something important we could overthink together...\n"
        "I’m Casey, and I worry a LOT, but that just means I’ll take your situation "
        "very seriously and try to guide you carefully."
    ),
}

# Per-persona sampling defaults: calmer for the analyst, wilder for the worrier.
PERSONA_GEN_PARAMS = {
    "Sunny Extrovert": dict(temperature=0.95, top_p=0.9),
    "Analytical Introvert": dict(temperature=0.6, top_p=0.8),
    "Dramatic Worrier": dict(temperature=1.05, top_p=0.95),
}
# Use a CUDA GPU when one is visible; otherwise (e.g. a CPU-only Space) run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INIT] Using device: {device}")
# ---------------- MODEL LOADING ----------------
# Directory containing app.py. Adapter folders are looked up here first so that
# the app does not depend on the process's current working directory.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))


def _resolve_adapter_path(path: str) -> str:
    """Return *path* joined onto the app directory if that folder exists, else *path* unchanged."""
    candidate = os.path.join(_APP_DIR, path)
    return candidate if os.path.isdir(candidate) else path


print("[INIT] Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
if tokenizer.pad_token is None:
    # No dedicated pad token: reuse EOS so generate() receives a valid pad_token_id.
    tokenizer.pad_token = tokenizer.eos_token

print("[INIT] Loading base model...")
# trust_remote_code is intentionally left at its default (False): the TinyLlama
# chat checkpoint uses an architecture built into transformers, and enabling it
# would let the Hub repo execute arbitrary Python at load time.
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
base_model.to(device)

# The first persona's adapter wraps the base model, so its folder is mandatory.
first_persona = next(iter(ADAPTER_PATHS))
first_adapter_path = _resolve_adapter_path(ADAPTER_PATHS[first_persona])
print(f"[INIT] Initializing PEFT with '{first_persona}' from '{first_adapter_path}'")
if not os.path.isdir(first_adapter_path):
    raise RuntimeError(
        f"Adapter path '{first_adapter_path}' not found. "
        "Make sure the folder exists in the Space repo."
    )
print(f"[INIT] Contents of '{first_adapter_path}': {os.listdir(first_adapter_path)}")
model = PeftModel.from_pretrained(
    base_model,
    first_adapter_path,
    adapter_name=first_persona,
)

# The remaining adapters are optional: a missing or unloadable one is logged and
# skipped rather than aborting start-up.
for name, raw_path in ADAPTER_PATHS.items():
    if name == first_persona:
        continue
    path = _resolve_adapter_path(raw_path)
    print(f"[INIT] Pre-loading adapter '{name}' from '{path}'")
    if not os.path.isdir(path):
        print(f"[WARN] Adapter path '{path}' does not exist. Skipping '{name}'.")
        continue
    try:
        print(f"[INIT] Contents of '{path}': {os.listdir(path)}")
        model.load_adapter(path, adapter_name=name)
    except Exception as e:
        print(f"[ERROR] Could not load adapter '{name}' from '{path}': {e}")

# Move everything (including freshly loaded adapter weights) to the device and
# switch to inference mode.
model.to(device)
model.eval()
print("[INIT] Model + adapters loaded.")
# ---------------- GENERATION LOGIC ----------------
def build_prompt(history, persona_name: str, eos_token: str = "</s>") -> str:
    """Render the chat history as a TinyLlama-Chat (Zephyr-style) prompt string.

    The persona is primed strongly by:
      - a generic system message,
      - the persona instruction as the first user turn,
      - a persona-styled primer as the first assistant turn,
      - then the real conversation.

    Args:
        history: list of [user, bot] pairs from the Gradio Chatbot. Before
            generation the last entry is [user, None]; None bot replies are
            skipped. ``None`` is treated as an empty history.
        persona_name: key into PERSONA_PROMPTS and PERSONA_PRIMERS.
        eos_token: terminator appended after every message. TinyLlama-Chat's
            chat template closes each turn with the tokenizer's EOS ("</s>");
            without it the model does not see where a turn ends and tends to
            keep writing further "<|user|>" turns itself.

    Returns:
        The prompt text, ending with an open "<|assistant|>\\n" turn.

    Raises:
        KeyError: if persona_name is not a known persona.
    """
    system_prompt = "You are a helpful AI assistant."
    persona_instruction = PERSONA_PROMPTS[persona_name]
    persona_primer = PERSONA_PRIMERS[persona_name]

    def turn(role: str, content) -> str:
        # One complete message in the "<|role|>\ncontent</s>\n" format.
        return f"<|{role}|>\n{content}{eos_token}\n"

    parts = [
        turn("system", system_prompt),
        # Persona priming as the first exchange.
        turn("user", persona_instruction),
        turn("assistant", persona_primer),
    ]
    # Real conversation.
    for user, bot in history or []:
        parts.append(turn("user", user))
        if bot is not None:
            parts.append(turn("assistant", bot))
    # Open the assistant turn the model should complete.
    parts.append("<|assistant|>\n")
    return "".join(parts)
# Folds typographic apostrophes to ASCII so phrase checks match either spelling
# (the canned taglines use "’", model output usually uses "'").
_APOSTROPHE_FOLD = str.maketrans({"\u2019": "'", "\u2018": "'"})
# A reply that already opens with a greeting word ("Hey!", "Hiya", "hello there").
# The word boundary keeps words like "History" from counting as "hi".
_GREETING_RE = re.compile(r"(?:hey+|hiya|hi+|hello+)\b", re.IGNORECASE)
# An inline list marker " 1." .. " 9." followed by whitespace or end of text.
# Requiring the whitespace leaves decimals such as "3.14" or "1.5" intact.
_INLINE_LIST_ITEM_RE = re.compile(r" ([1-9])\.(?=\s|$)")


def _fold(text: str) -> str:
    """Lower-case *text* and normalise apostrophes for phrase comparisons."""
    return text.lower().translate(_APOSTROPHE_FOLD)


def _lower_first(text: str) -> str:
    """Lower-case the first character unless it is the pronoun "I" or starts an acronym."""
    if len(text) > 1 and text[1].isupper():
        return text  # e.g. "API", "OK"
    if text[0] == "I" and (len(text) == 1 or not text[1].isalpha()):
        return text  # e.g. "I think", "I'm", "I’ll"
    return text[0].lower() + text[1:]


def _stylize_sunny(reply: str) -> str:
    """Ensure a cheerful greeting up front and the encouragement tagline at the end."""
    if not _GREETING_RE.match(reply):
        reply = "Hey there!! :D " + reply
    if "you've totally got this" not in _fold(reply):
        reply = reply.rstrip() + "\n\nAnd remember, you’ve totally got this! :)"
    return reply


def _stylize_analytical(reply: str) -> str:
    """Ensure an "Analysis:" header and put inline numbered items on their own lines."""
    if not reply.lower().startswith("analysis:"):
        reply = "Analysis:\n" + reply
    return _INLINE_LIST_ITEM_RE.sub(r"\n\1.", reply)


def _stylize_worrier(reply: str) -> str:
    """Ensure a dramatic "Oh no"/"Honestly" opening and a worried closing line."""
    if not _fold(reply).startswith(("oh no", "honestly")):
        if reply:
            reply = "Oh no, " + _lower_first(reply)
        else:
            reply = "Oh no, I can’t help worrying about this already..."
    # Re-check the *current* text so the empty-reply fallback above is not followed
    # by a second, redundant worry sentence.
    if "i can't help worrying" not in _fold(reply):
        reply = reply.rstrip() + (
            "\n\nHonestly, I can’t help worrying about how this might go... "
            "but if you prepare a bit carefully, it will almost certainly turn out better than you fear."
        )
    return reply


_STYLIZERS = {
    "Sunny Extrovert": _stylize_sunny,
    "Analytical Introvert": _stylize_analytical,
    "Dramatic Worrier": _stylize_worrier,
}


def stylize_reply(reply: str, persona_name: str) -> str:
    """Post-process a raw model reply to force exaggerated persona differences.

    Surface markers (greetings, headers, taglines, list layout) are added on top
    of the model output, even if the underlying text is similar across personas.

    Args:
        reply: raw decoded model text; surrounding whitespace is stripped first.
        persona_name: persona display name; unknown names get only the strip.

    Returns:
        The stylized reply.
    """
    reply = reply.strip()
    stylizer = _STYLIZERS.get(persona_name)
    return stylizer(reply) if stylizer is not None else reply
# Role tags used by the prompt format. They are not special tokens, so when the
# model starts writing the next turn itself they survive decoding and must be cut.
_TURN_MARKERS = ("<|user|>", "<|assistant|>", "<|system|>")


def _cut_at_next_turn(text: str) -> str:
    """Return *text* up to the first role tag the model generated, stripped."""
    for marker in _TURN_MARKERS:
        text = text.split(marker, 1)[0]
    return text.strip()


def generate_reply(history, persona_name, tts_enabled, temperature=0.8, max_tokens=120):
    """Generate the assistant reply for the pending chat turn and optionally speak it.

    Args:
        history: Gradio Chatbot history, a list of [user, bot] pairs whose last
            entry should be [user, None] (the turn awaiting a reply).
        persona_name: persona whose LoRA adapter, prompt and post-processing are used.
        tts_enabled: when truthy, the reply is synthesised with gTTS into
            "tts_output.mp3".
        temperature: UI sampling temperature; when not None it replaces the
            persona's default from PERSONA_GEN_PARAMS.
        max_tokens: maximum number of new tokens (cast to int, at least 1;
            None means 120).

    Returns:
        (history, history, audio_path): the updated history twice (the UI binds
        it to two outputs) and the mp3 path, or None when TTS is off or failed.
        If there is no pending turn (empty history, or the last turn already has a
        reply) nothing is generated and the history is returned unchanged.
    """
    # A blank submission leaves the history unchanged, but the UI still chains into
    # this function; generating here would silently replace the last answer (or
    # waste a generation on an empty chat).
    if not history or history[-1][1] is not None:
        return history, history, None

    try:
        model.set_adapter(persona_name)
    except Exception as e:
        # Best effort: continue with whichever adapter is currently active.
        print(f"[ERROR] set_adapter('{persona_name}') failed: {e}")
    print("[GEN] Active adapter:", getattr(model, "active_adapter", None))

    prompt = build_prompt(history, persona_name)
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Start from the persona's decoding defaults, then apply the UI overrides.
    # NOTE(review): the Temperature slider always supplies a value, so the persona
    # temperature only takes effect when this is called with temperature=None.
    params = dict(PERSONA_GEN_PARAMS.get(persona_name, {"temperature": 0.8, "top_p": 0.9}))
    if temperature is not None:
        params["temperature"] = float(temperature)
    # Sliders deliver floats; generate() needs a positive integer.
    max_tokens = max(1, int(max_tokens)) if max_tokens is not None else 120

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            top_p=params["top_p"],
            temperature=params["temperature"],
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    new_ids = output_ids[0][inputs["input_ids"].shape[-1]:]
    generated = tokenizer.decode(new_ids, skip_special_tokens=True)
    # Force exaggerated style differences on top of the (trimmed) raw reply.
    reply = stylize_reply(_cut_at_next_turn(generated), persona_name)

    last_user = history[-1][0]
    history[-1] = [last_user, reply]

    audio_path = None
    if tts_enabled:
        try:
            tts = gTTS(reply)
            audio_path = "tts_output.mp3"
            tts.save(audio_path)
        except Exception as e:
            # TTS is optional: log and return the text reply without audio.
            print("[TTS] Error:", e)
            audio_path = None
    return history, history, audio_path
# ---------------- GRADIO UI (UPDATED) ----------------
# Custom CSS for UTRGV orange theme
# Passed to gr.Blocks(css=...). Sets a dark (#1a1a1a) container background and
# uses orange #FF6600 for headings, labels, `.message.user` elements (presumably
# the user's chat bubbles — the class name depends on the Gradio version, verify),
# range-slider accents and the border of focused inputs. `!important` is used
# so these rules win over the base theme's styles.
custom_css = """
.gradio-container {
background: #1a1a1a !important;
}
h1, h2, h3 {
color: #FF6600 !important;
}
label {
color: #FF6600 !important;
}
.message.user {
background: #FF6600 !important;
}
input[type="range"] {
accent-color: #FF6600 !important;
}
input:focus, textarea:focus, select:focus {
border-color: #FF6600 !important;
}
"""
with gr.Blocks(theme=gr.themes.Base(), css=custom_css) as demo:
    gr.Markdown("# Multi-Personality AI Chatbot")

    # Top row: persona picker next to the text-to-speech toggle.
    with gr.Row():
        persona_dropdown = gr.Dropdown(
            label="Select Personality",
            value=first_persona,
            choices=list(ADAPTER_PATHS.keys()),
        )
        tts_checkbox = gr.Checkbox(value=False, label="Enable Text-to-Speech")

    chat = gr.Chatbot(label="Conversation")
    msg = gr.Textbox(placeholder="Type your message...", label="Your message")

    # Sampling controls forwarded to generate_reply().
    with gr.Row():
        temperature = gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=1.5,
            step=0.1,
            value=0.8,
        )
        max_tokens = gr.Slider(
            label="Max Tokens",
            minimum=50,
            maximum=500,
            step=10,
            value=120,
        )

    audio_out = gr.Audio(autoplay=True, label="Audio Response")
    clear_btn = gr.Button("Clear Chat")

    def user_submit(user_message, history):
        """Clear the textbox and append the message as a pending [user, None] turn.

        Whitespace-only messages leave the (possibly None -> []) history as is.
        """
        history = history or []
        if user_message.strip():
            history = history + [[user_message, None]]
        return "", history

    # Step 1 (unqueued) records the user's turn; step 2 generates the reply.
    generation_inputs = [chat, persona_dropdown, tts_checkbox, temperature, max_tokens]
    msg.submit(
        user_submit,
        inputs=[msg, chat],
        outputs=[msg, chat],
        queue=False,
    ).then(
        generate_reply,
        inputs=generation_inputs,
        outputs=[chat, chat, audio_out],
    )

    # Reset the conversation and the audio player.
    clear_btn.click(lambda: ([], None), outputs=[chat, audio_out])

if __name__ == "__main__":
    demo.launch()