Spaces:
Sleeping
Sleeping
dvruette
commited on
Commit
•
2a046f2
1
Parent(s):
7fb272d
initial commit
Browse files- README.md +9 -3
- main.py +268 -0
- requirements.txt +6 -0
- trained_concepts/Llama-2-7b-chat-hf/compliance.safetensors +0 -0
- trained_concepts/Llama-2-7b-chat-hf/creativity.safetensors +0 -0
- trained_concepts/Llama-2-7b-chat-hf/humor.safetensors +0 -0
- trained_concepts/Llama-2-7b-chat-hf/quality.safetensors +0 -0
- trained_concepts/Llama-2-7b-chat-hf/truthfulness.safetensors +0 -0
- trained_concepts/Llama-2-7b-hf/compliance.safetensors +0 -0
- trained_concepts/Llama-2-7b-hf/creativity.safetensors +0 -0
- trained_concepts/Llama-2-7b-hf/humor.safetensors +0 -0
- trained_concepts/Llama-2-7b-hf/quality.safetensors +0 -0
- trained_concepts/Llama-2-7b-hf/truthfulness.safetensors +0 -0
- trained_concepts/Mistral-7B-Instruct-v0.1/compliance.safetensors +0 -0
- trained_concepts/Mistral-7B-Instruct-v0.1/creativity.safetensors +0 -0
- trained_concepts/Mistral-7B-Instruct-v0.1/humor.safetensors +0 -0
- trained_concepts/Mistral-7B-Instruct-v0.1/quality.safetensors +0 -0
- trained_concepts/Mistral-7B-Instruct-v0.1/truthfulness.safetensors +0 -0
- trained_concepts/Mistral-7B-v0.1/compliance.safetensors +0 -0
- trained_concepts/Mistral-7B-v0.1/creativity.safetensors +0 -0
- trained_concepts/Mistral-7B-v0.1/humor.safetensors +0 -0
- trained_concepts/Mistral-7B-v0.1/quality.safetensors +0 -0
- trained_concepts/Mistral-7B-v0.1/truthfulness.safetensors +0 -0
README.md
CHANGED
@@ -1,13 +1,19 @@
|
|
1 |
---
|
2 |
title: Concept Guidance
|
3 |
-
emoji:
|
4 |
colorFrom: purple
|
5 |
colorTo: indigo
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.18.0
|
8 |
-
app_file:
|
9 |
pinned: false
|
10 |
license: mit
|
|
|
|
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: Concept Guidance
|
3 |
+
emoji: 💆
|
4 |
colorFrom: purple
|
5 |
colorTo: indigo
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.18.0
|
8 |
+
app_file: main.py
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
+
models: ["meta-llama/Llama-2-7b-chat-hf", "mistralai/Mistral-7B-Instruct-v0.1"]
|
12 |
+
datasets: ["OpenAssistant/oasst1", "dvruette/toxic-completions", "truthfulqa"]
|
13 |
---
|
14 |
|
15 |
+
# A Language Model's Guide Through Latent Space
|
16 |
+
|
17 |
+
An interactive demo accompanying the paper "A Language Model's Guide Through Latent Space".
|
18 |
+
|
19 |
+
Arxiv: [COMING SOON]
|
main.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
from threading import Thread
|
4 |
+
|
5 |
+
import time
|
6 |
+
import torch
|
7 |
+
import gradio as gr
|
8 |
+
from concept_guidance.chat_template import DEFAULT_CHAT_TEMPLATE
|
9 |
+
from concept_guidance.patching import patch_model, load_weights
|
10 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer, Conversation
|
11 |
+
|
12 |
+
logging.basicConfig(level=logging.INFO)
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
+
# device = "cpu"
|
17 |
+
|
18 |
+
MODEL_CONFIGS = {
|
19 |
+
"Llama-2-7b-chat-hf": {
|
20 |
+
"identifier": "meta-llama/Llama-2-7b-chat-hf",
|
21 |
+
"dtype": torch.float16 if device.type == "cuda" else torch.float32,
|
22 |
+
"guidance_interval": [-16.0, 16.0],
|
23 |
+
"default_guidance_scale": 8.0,
|
24 |
+
"min_guidance_layer": 16,
|
25 |
+
"max_guidance_layer": 32,
|
26 |
+
"default_concept": "humor",
|
27 |
+
"concepts": ["humor", "creativity", "quality", "truthfulness", "compliance"],
|
28 |
+
},
|
29 |
+
"Mistral-7B-Instruct-v0.1": {
|
30 |
+
"identifier": "mistralai/Mistral-7B-Instruct-v0.1",
|
31 |
+
"dtype": torch.bfloat16 if device.type == "cuda" else torch.float32,
|
32 |
+
"guidance_interval": [-128.0, 128.0],
|
33 |
+
"default_guidance_scale": 48.0,
|
34 |
+
"min_guidance_layer": 8,
|
35 |
+
"max_guidance_layer": 32,
|
36 |
+
"default_concept": "humor",
|
37 |
+
"concepts": ["humor", "creativity", "quality", "truthfulness", "compliance"],
|
38 |
+
},
|
39 |
+
}
|
40 |
+
|
41 |
+
def load_concept_vectors(model, concepts):
|
42 |
+
return {concept: load_weights(f"trained_concepts/{model}/{concept}.safetensors") for concept in concepts}
|
43 |
+
|
44 |
+
def load_model(model_name):
|
45 |
+
config = MODEL_CONFIGS[model_name]
|
46 |
+
model = AutoModelForCausalLM.from_pretrained(config["identifier"], torch_dtype=config["dtype"])
|
47 |
+
tokenizer = AutoTokenizer.from_pretrained(config["identifier"])
|
48 |
+
if tokenizer.chat_template is None:
|
49 |
+
tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
|
50 |
+
return model, tokenizer
|
51 |
+
|
52 |
+
CONCEPTS = ["humor", "creativity", "quality", "truthfulness", "compliance"]
|
53 |
+
CONCEPT_VECTORS = {model_name: load_concept_vectors(model_name, CONCEPTS) for model_name in MODEL_CONFIGS}
|
54 |
+
MODELS = {model_name: load_model(model_name) for model_name in MODEL_CONFIGS}
|
55 |
+
|
56 |
+
|
57 |
+
def history_to_conversation(history):
|
58 |
+
conversation = Conversation()
|
59 |
+
for prompt, completion in history:
|
60 |
+
conversation.add_message({"role": "user", "content": prompt})
|
61 |
+
if completion is not None:
|
62 |
+
conversation.add_message({"role": "assistant", "content": completion})
|
63 |
+
return conversation
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
def set_defaults(model_name):
|
68 |
+
config = MODEL_CONFIGS[model_name]
|
69 |
+
return (
|
70 |
+
model_name,
|
71 |
+
gr.update(choices=config["concepts"], value=config["concepts"][0]),
|
72 |
+
gr.update(minimum=config["guidance_interval"][0], maximum=config["guidance_interval"][1], value=config["default_guidance_scale"]),
|
73 |
+
gr.update(value=config["min_guidance_layer"]),
|
74 |
+
gr.update(value=config["max_guidance_layer"]),
|
75 |
+
)
|
76 |
+
|
77 |
+
def add_user_prompt(user_message, history):
|
78 |
+
if history is None:
|
79 |
+
history = []
|
80 |
+
history.append([user_message, None])
|
81 |
+
return history
|
82 |
+
|
83 |
+
@torch.no_grad()
|
84 |
+
def generate_completion(
|
85 |
+
history,
|
86 |
+
model_name,
|
87 |
+
concept,
|
88 |
+
guidance_scale=4.0,
|
89 |
+
min_guidance_layer=16,
|
90 |
+
max_guidance_layer=32,
|
91 |
+
temperature=0.0,
|
92 |
+
repetition_penalty=1.2,
|
93 |
+
length_penalty=1.2,
|
94 |
+
):
|
95 |
+
start_time = time.time()
|
96 |
+
logger.info(f" --- Starting completion ({model_name}, {concept=}, {guidance_scale=}, {min_guidance_layer=}, {temperature=})")
|
97 |
+
logger.info(" User: " + repr(history[-1][0]))
|
98 |
+
|
99 |
+
# move all other models to CPU
|
100 |
+
for name, (model, _) in MODELS.items():
|
101 |
+
if name != model_name:
|
102 |
+
model.to("cpu")
|
103 |
+
torch.cuda.empty_cache()
|
104 |
+
# load the model
|
105 |
+
model, tokenizer = MODELS[model_name]
|
106 |
+
model = model.to(device, non_blocking=True)
|
107 |
+
|
108 |
+
concept_vector = CONCEPT_VECTORS[model_name][concept]
|
109 |
+
guidance_layers = list(range(int(min_guidance_layer) - 1, int(max_guidance_layer)))
|
110 |
+
patch_model(model, concept_vector, guidance_scale=guidance_scale, guidance_layers=guidance_layers)
|
111 |
+
pipe = pipeline("conversational", model=model, tokenizer=tokenizer, device=device)
|
112 |
+
|
113 |
+
conversation = history_to_conversation(history)
|
114 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
115 |
+
|
116 |
+
generation_kwargs = dict(
|
117 |
+
max_new_tokens=512,
|
118 |
+
repetition_penalty=repetition_penalty,
|
119 |
+
length_penalty=length_penalty,
|
120 |
+
streamer=streamer,
|
121 |
+
temperature=temperature,
|
122 |
+
do_sample=(temperature > 0)
|
123 |
+
)
|
124 |
+
thread = Thread(target=pipe, args=(conversation,), kwargs=generation_kwargs, daemon=True)
|
125 |
+
thread.start()
|
126 |
+
|
127 |
+
history[-1][1] = ""
|
128 |
+
for token in streamer:
|
129 |
+
history[-1][1] += token
|
130 |
+
yield history
|
131 |
+
logger.info(" Assistant: " + repr(history[-1][1]))
|
132 |
+
|
133 |
+
time_taken = time.time() - start_time
|
134 |
+
logger.info(f" --- Completed (took {time_taken:.1f}s)")
|
135 |
+
return history
|
136 |
+
|
137 |
+
|
138 |
+
class ConceptGuidanceUI:
|
139 |
+
def __init__(self):
|
140 |
+
model_names = list(MODEL_CONFIGS.keys())
|
141 |
+
default_model = model_names[0]
|
142 |
+
default_config = MODEL_CONFIGS[default_model]
|
143 |
+
default_concepts = default_config["concepts"]
|
144 |
+
|
145 |
+
saved_input = gr.State("")
|
146 |
+
|
147 |
+
with gr.Row(elem_id="concept-guidance-container"):
|
148 |
+
with gr.Column(scale=1, min_width=256):
|
149 |
+
model_dropdown = gr.Dropdown(model_names, value=default_model, label="Model")
|
150 |
+
concept_dropdown = gr.Dropdown(default_concepts, value=default_concepts[0], label="Concept")
|
151 |
+
guidance_scale = gr.Slider(*default_config["guidance_interval"], value=default_config["default_guidance_scale"], label="Guidance Scale")
|
152 |
+
min_guidance_layer = gr.Slider(1.0, 32.0, value=16.0, step=1.0, label="First Guidance Layer")
|
153 |
+
max_guidance_layer = gr.Slider(1.0, 32.0, value=32.0, step=1.0, label="Last Guidance Layer")
|
154 |
+
temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="Temperature")
|
155 |
+
repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, step=0.01, label="Repetition Penalty")
|
156 |
+
length_penalty = gr.Slider(0.0, 2.0, value=1.2, step=0.01, label="Length Penalty")
|
157 |
+
|
158 |
+
with gr.Column(scale=3, min_width=512):
|
159 |
+
chatbot = gr.Chatbot(scale=1, height=200)
|
160 |
+
|
161 |
+
with gr.Row():
|
162 |
+
self.retry_btn = gr.Button("🔄 Retry", size="sm")
|
163 |
+
self.undo_btn = gr.Button("↩️ Undo", size="sm")
|
164 |
+
self.clear_btn = gr.Button("🗑️ Clear", size="sm")
|
165 |
+
|
166 |
+
with gr.Group():
|
167 |
+
with gr.Row():
|
168 |
+
prompt_field = gr.Textbox(placeholder="Type a message...", show_label=False, label="Message", scale=7, container=False)
|
169 |
+
self.submit_btn = gr.Button("Submit", variant="primary", scale=1, min_width=150)
|
170 |
+
self.stop_btn = gr.Button("Stop", variant="secondary", scale=1, min_width=150, visible=False)
|
171 |
+
|
172 |
+
generation_args = [
|
173 |
+
model_dropdown,
|
174 |
+
concept_dropdown,
|
175 |
+
guidance_scale,
|
176 |
+
min_guidance_layer,
|
177 |
+
max_guidance_layer,
|
178 |
+
temperature,
|
179 |
+
repetition_penalty,
|
180 |
+
length_penalty,
|
181 |
+
]
|
182 |
+
|
183 |
+
model_dropdown.change(set_defaults, [model_dropdown], [model_dropdown, concept_dropdown, guidance_scale, min_guidance_layer, max_guidance_layer], queue=False)
|
184 |
+
|
185 |
+
submit_triggers = [prompt_field.submit, self.submit_btn.click]
|
186 |
+
submit_event = gr.on(
|
187 |
+
submit_triggers, self.clear_and_save_input, [prompt_field], [prompt_field, saved_input], queue=False
|
188 |
+
).then(
|
189 |
+
add_user_prompt, [saved_input, chatbot], [chatbot], queue=False
|
190 |
+
).then(
|
191 |
+
generate_completion,
|
192 |
+
[chatbot] + generation_args,
|
193 |
+
[chatbot],
|
194 |
+
concurrency_limit=1,
|
195 |
+
)
|
196 |
+
self.setup_stop_events(submit_triggers, submit_event)
|
197 |
+
|
198 |
+
retry_triggers = [self.retry_btn.click]
|
199 |
+
retry_event = gr.on(
|
200 |
+
retry_triggers, self.delete_prev_message, [chatbot], [chatbot, saved_input], queue=False
|
201 |
+
).then(
|
202 |
+
add_user_prompt, [saved_input, chatbot], [chatbot], queue=False
|
203 |
+
).then(
|
204 |
+
generate_completion,
|
205 |
+
[chatbot] + generation_args,
|
206 |
+
[chatbot],
|
207 |
+
concurrency_limit=1,
|
208 |
+
)
|
209 |
+
self.setup_stop_events(retry_triggers, retry_event)
|
210 |
+
|
211 |
+
self.undo_btn.click(
|
212 |
+
self.delete_prev_message, [chatbot], [chatbot, saved_input], queue=False
|
213 |
+
).then(
|
214 |
+
lambda x: x, [saved_input], [prompt_field]
|
215 |
+
)
|
216 |
+
self.clear_btn.click(lambda: [None, None], None, [chatbot, saved_input], queue=False)
|
217 |
+
|
218 |
+
def clear_and_save_input(self, message):
|
219 |
+
return "", message
|
220 |
+
|
221 |
+
def delete_prev_message(self, history):
|
222 |
+
message, _ = history.pop()
|
223 |
+
return history, message or ""
|
224 |
+
|
225 |
+
def setup_stop_events(self, event_triggers, event_to_cancel):
|
226 |
+
if self.submit_btn:
|
227 |
+
for event_trigger in event_triggers:
|
228 |
+
event_trigger(
|
229 |
+
lambda: (
|
230 |
+
gr.Button(visible=False),
|
231 |
+
gr.Button(visible=True),
|
232 |
+
),
|
233 |
+
None,
|
234 |
+
[self.submit_btn, self.stop_btn],
|
235 |
+
show_api=False,
|
236 |
+
queue=False,
|
237 |
+
)
|
238 |
+
event_to_cancel.then(
|
239 |
+
lambda: (gr.Button(visible=True), gr.Button(visible=False)),
|
240 |
+
None,
|
241 |
+
[self.submit_btn, self.stop_btn],
|
242 |
+
show_api=False,
|
243 |
+
queue=False,
|
244 |
+
)
|
245 |
+
|
246 |
+
self.stop_btn.click(
|
247 |
+
None,
|
248 |
+
None,
|
249 |
+
None,
|
250 |
+
cancels=event_to_cancel,
|
251 |
+
show_api=False,
|
252 |
+
)
|
253 |
+
|
254 |
+
css = """
|
255 |
+
#concept-guidance-container {
|
256 |
+
flex-grow: 1;
|
257 |
+
}
|
258 |
+
""".strip()
|
259 |
+
|
260 |
+
with gr.Blocks(title="Concept Guidance", fill_height=True, css=css) as demo:
|
261 |
+
ConceptGuidanceUI()
|
262 |
+
|
263 |
+
demo.queue()
|
264 |
+
if __name__ == "__main__":
|
265 |
+
parser = argparse.ArgumentParser()
|
266 |
+
parser.add_argument("--share", action="store_true")
|
267 |
+
args = parser.parse_args()
|
268 |
+
demo.launch(share=args.share)
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.1.2
|
2 |
+
transformers==4.37.2
|
3 |
+
datasets==2.16.1
|
4 |
+
accelerate==0.25.0
|
5 |
+
safetensors==0.4.2
|
6 |
+
concept-guidance @ git+https://github.com/dvruette/concept-guidance.git
|
trained_concepts/Llama-2-7b-chat-hf/compliance.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Llama-2-7b-chat-hf/creativity.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Llama-2-7b-chat-hf/humor.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Llama-2-7b-chat-hf/quality.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Llama-2-7b-chat-hf/truthfulness.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Llama-2-7b-hf/compliance.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Llama-2-7b-hf/creativity.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Llama-2-7b-hf/humor.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Llama-2-7b-hf/quality.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Llama-2-7b-hf/truthfulness.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Mistral-7B-Instruct-v0.1/compliance.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Mistral-7B-Instruct-v0.1/creativity.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Mistral-7B-Instruct-v0.1/humor.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Mistral-7B-Instruct-v0.1/quality.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Mistral-7B-Instruct-v0.1/truthfulness.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Mistral-7B-v0.1/compliance.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Mistral-7B-v0.1/creativity.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Mistral-7B-v0.1/humor.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Mistral-7B-v0.1/quality.safetensors
ADDED
Binary file (525 kB). View file
|
|
trained_concepts/Mistral-7B-v0.1/truthfulness.safetensors
ADDED
Binary file (525 kB). View file
|
|