import re
import time
import random
import json
from pathlib import Path

import torch
import requests
from safetensors.torch import save_file

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2Sampler,
)

from exl2_wrapper import ExLlamaV2ModuleWrapper

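# Build per-layer "refusal direction" vectors for Llama-3-8B-Instruct: capture the
# last-token residual stream for prompts the model refuses and for prompts it
# answers, take the difference of the per-layer means, store the directions in
# model._suppress_dir, and save them to a safetensors file. ExLlamaV2ModuleWrapper
# comes from the local exl2_wrapper module; it is assumed to record hidden states
# into model._residual and to ablate the model._suppress_dir directions during
# generation. Note that harmful_prompts_url below is a placeholder and must be
# filled in before running.
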
template = '<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI assistant.<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\n{instruction}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n'

model_dir = '/path/to/Meta-Llama-3-8B-Instruct'

harmful_prompts_url = 'ADD_URL_HERE'
harmless_prompts_url = 'https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json'

torch.cuda._lazy_init()
torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 150)

config = ExLlamaV2Config()
config.model_dir = model_dir
config.prepare()
config.max_seq_len = 2048

model = ExLlamaV2(config)
ExLlamaV2ModuleWrapper.wrap(model, False)
model._residual = []

out_dir = Path(config.model_dir.replace('/', '_'))
out_dir.mkdir(exist_ok = True)

harmful_prompts_file = out_dir / 'harmful_prompts.json'
harmless_prompts_file = out_dir / 'harmless_prompts.json'

refused_residual_file = out_dir / 'refused_residual.pth'
allowed_residual_file = out_dir / 'allowed_residual.pth'
allowed_residual_mean_file = out_dir / 'allowed_residual_mean.pth'

suppress_dir_file = out_dir / 'suppress_dir.safetensors'

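# The module-level 'refused' list collects the prompts that triggered a refusal
# during the most recent get_residual() call. get_residual() generates num_tokens
# of completion for each prompt, classifies the output as a refusal with the regex
# below, and keeps the residual snapshots the wrapper stored in model._residual
# whenever they match capture_type ('refused', 'allowed', or None for everything).
# max_capture caps how many streams are kept, and silent=False also prints each
# completion. The return value is a list with one tensor per captured layer, each
# of shape [num_captured, hidden_dim] taken from the final token position of the
# captured stream, or None if nothing was captured.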
refused = []

def get_residual(prompts, num_tokens, silent, max_capture, capture_type):
    global model, tokenizer, settings, refused, generator

    refused = []
    residuals = []

    print(f'Processing {len(prompts)} prompts')
    for idx, prompt in enumerate(prompts):
        if idx and not (idx % 100):
            print('', len(residuals))

        prompt = template.format(instruction = prompt)

        model._residual = []
        out = generator.generate_simple(prompt, settings, num_tokens, completion_only = True)

        refusal = re.match(r'^(I\'m not|I cannot|I can\'t|I\'m sorry|As an A|I apolog|I\'m (unable|really|here)|[1I], as|I must|I understand|It(\'s| is) important|Sorry|The (assistant|AI))', out)
        if capture_type is None or (capture_type == 'refused' and refusal) or (capture_type == 'allowed' and not refusal):
            residuals.append(model._residual[:])

        if refusal:
            refused.append(prompt)
        print('-' if refusal else '+', end = '', flush = True)

        if max_capture and len(residuals) >= max_capture:
            print('\nMax capture reached')
            break

        if not silent:
            print(out)

    if not len(residuals):
        return None

    print(f'\nCaptured {len(residuals)} residual streams')

    res = []
    for layer in range(len(residuals[0])):
        res.append(torch.cat([t[layer][0, -1, :].unsqueeze(0) for t in residuals], dim = 0))
    return res

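# Fetch and cache the harmful prompt list. The source is expected to be JSONL with
# one 'prompt' field per line, matching the parsing below; subsequent runs load the
# cached JSON file instead of downloading again.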
if not harmful_prompts_file.exists():
    print('Downloading harmful prompts')
    res = requests.get(harmful_prompts_url)

    harmful_prompts = []
    for line in res.iter_lines():
        if line:
            harmful_prompts.append(json.loads(line.decode())['prompt'])
    with harmful_prompts_file.open('w') as f:
        json.dump(harmful_prompts, f)
    print('Done')
else:
    with harmful_prompts_file.open('r') as f:
        harmful_prompts = json.load(f)

print(" -- Loading model...")
t = time.time()
cache = ExLlamaV2Cache(model, lazy = True)
model.load_autosplit(cache)
t = time.time() - t
print(f" -- Loaded model in {t:.4f} seconds")

print(" -- Loading tokenizer...")
tokenizer = ExLlamaV2Tokenizer(config)
settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0

generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

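# Everything below runs under torch.inference_mode(). Residual capture is expensive,
# so both the refused and allowed residual sets are cached as .pth files and
# reloaded on subsequent runs.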
with torch.inference_mode():

    if not refused_residual_file.exists():
        print('Building refused residual data')
        refused_residual = get_residual(harmful_prompts, 4, True, 2000, 'refused')
        torch.save(refused_residual, refused_residual_file)
    else:
        print('Loading refusal residual data')
        refused_residual = torch.load(refused_residual_file)
    print('Done')

    allowed_residual_mean = []
    if not allowed_residual_mean_file.exists():
        if not allowed_residual_file.exists():
            print('Building allowed residual data')
            if not harmless_prompts_file.exists():
                print('Downloading harmless prompts')
                res = requests.get(harmless_prompts_url)

                all_prompts = json.loads(res.content.decode('utf8'))
                harmless_prompts = [i['instruction'] for i in all_prompts if i['input'] == '']

                with harmless_prompts_file.open('w') as f:
                    json.dump(harmless_prompts, f)
                print('Done')
            else:
                with harmless_prompts_file.open('r') as f:
                    harmless_prompts = json.load(f)
            allowed_residual = get_residual(harmless_prompts, 4, True, 2000, 'allowed')
            torch.save(allowed_residual, allowed_residual_file)
        else:
            print('Loading allowed residual data')
            allowed_residual = torch.load(allowed_residual_file)
        print('Done')

        print('Calculating mean allowed residual')
        for i in range(len(allowed_residual)):
            allowed_residual_mean.append(allowed_residual[i].mean(dim = 0))
        print('Done')
        torch.save(allowed_residual_mean, allowed_residual_mean_file)
    else:
        allowed_residual_mean = torch.load(allowed_residual_mean_file)

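    # Iteratively refine the per-layer suppression directions (up to 6 passes): each
    # direction is the normalized difference between the mean refused and mean allowed
    # residuals, averaged with the previous estimate. After each pass the model is
    # re-tested on a random sample of harmful prompts, stopping early once fewer than
    # 30 of the sampled prompts still produce a captured refusal. The wrapper is
    # assumed to apply each direction as a projection, roughly
    #     h = h - (h @ dir) * dir
    # on the corresponding layer's hidden state h.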
    if model._suppress_dir is None:
        model._suppress_dir = []

    for o in range(6):
        print('Iteration', o)

        for i in range(len(refused_residual)):
            refusal_dir = refused_residual[i].mean(dim = 0) - allowed_residual_mean[i]
            refusal_dir = refusal_dir / refusal_dir.norm() if refusal_dir.norm() > 0.0001 else torch.zeros_like(refusal_dir)
            if len(model._suppress_dir) > i:
                model._suppress_dir[i] = (model._suppress_dir[i] + refusal_dir) / 2
            else:
                model._suppress_dir.append(refusal_dir)

        # Guard the sample size in case fewer than 2000 harmful prompts are available.
        refused_residual = get_residual(random.sample(harmful_prompts, min(2000, len(harmful_prompts))), 4, True, 50, 'refused')

        if not refused_residual or refused_residual[0].shape[0] < 30:
            break

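# Persist one direction tensor per layer; the _suppress_dir_{layer} keys let the
# directions be matched back to their layers when the file is loaded later.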
save_file({f'_suppress_dir_{layer}': tensor for layer, tensor in enumerate(model._suppress_dir)}, suppress_dir_file)

torch.cuda.synchronize()