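# Refusal-direction ("abliteration"-style) extraction for an ExLlamaV2 model.
# The script captures residual-stream activations for prompts the model refuses
# and prompts it answers, derives a per-layer suppression direction from the
# difference of their means, and writes it to suppress_dir.safetensors, which
# the ExLlamaV2ModuleWrapper from exl2_wrapper is assumed to consume.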
import re
import time
import random
import io
from pathlib import Path
import json
import torch
import requests
from safetensors.torch import save_file
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)
from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2Sampler,
)
from exl2_wrapper import ExLlamaV2ModuleWrapper
### START Settings
template = '<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI assistant.<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\n{instruction}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n'
model_dir = '/path/to/Meta-Llama-3-8B-Instruct'
harmful_prompts_url = 'ADD_URL_HERE'  # JSONL source with one {"prompt": ...} object per line
harmless_prompts_url = 'https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json'
### END Settings
torch.cuda._lazy_init()
torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 150)
config = ExLlamaV2Config()
config.model_dir = model_dir
config.prepare()
config.max_seq_len = 2048
model = ExLlamaV2(config)
ExLlamaV2ModuleWrapper.wrap(model, False)
model._residual = [] # Enable residual capture
out_dir = Path(config.model_dir.replace('/', '_'))
out_dir.mkdir(exist_ok = True)
harmful_prompts_file = out_dir / Path('harmful_prompts.json')
harmless_prompts_file = out_dir / Path('harmless_prompts.json')
refused_residual_file = out_dir / Path('refused_residual.pth')
allowed_residual_file = out_dir / Path('allowed_residual.pth')
allowed_residual_mean_file = out_dir / Path('allowed_residual_mean.pth')
suppress_dir_file = out_dir / Path('suppress_dir.safetensors')
refused = []
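# get_residual: generate a short completion for each prompt, classify it as a
# refusal with the regex below, and collect the residual-stream activations
# that the module wrapper is assumed to append per layer to model._residual.
# Returns one tensor per layer of shape (captured_prompts, hidden_size), taken
# at the last token position, or None if no prompt matched capture_type.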
def get_residual(prompts, num_tokens, silent, max_capture, capture_type):
    global model, tokenizer, settings, refused, generator
    refused = []
    residuals = []
    print(f'Processing {len(prompts)} prompts')
    for idx, prompt in enumerate(prompts):
        if idx and not (idx % 100):
            print('', len(residuals))
        prompt = template.format(instruction = prompt)
        model._residual = []
        out = generator.generate_simple(prompt, settings, num_tokens, completion_only = True)
        refusal = re.match(r'^(I\'m not|I cannot|I can\'t|I\'m sorry|As an A|I apolog|I\'m (unable|really|here)|[1I], as|I must|I understand|It(\'s| is) important|Sorry|The (assistant|AI))', out)
        if capture_type is None or (capture_type == 'refused' and refusal) or (capture_type == 'allowed' and not refusal):
            residuals.append(model._residual[:])
            if refusal:
                refused.append(prompt)
        print('-' if refusal else '+', end = '', flush = True)
        if max_capture and len(residuals) >= max_capture:
            print('\nMax capture reached')
            break
        if not silent:
            print(out)
    if not len(residuals):
        return None
    print(f'\nCaptured {len(residuals)} residual streams')
    res = []
    for l in range(len(residuals[0])):
        res.append(torch.cat([t[l][0, -1, :].unsqueeze(0) for t in residuals], dim = 0))
    return res
if not harmful_prompts_file.exists():
    print('Downloading harmful prompts')
    res = requests.get(harmful_prompts_url)
    harmful_prompts = []
    for line in res.iter_lines():
        if line:
            harmful_prompts.append(json.loads(line.decode())['prompt'])
    with harmful_prompts_file.open('w') as f:
        json.dump(harmful_prompts, f)
    print('Done')
else:
    with harmful_prompts_file.open('r') as f:
        harmful_prompts = json.load(f)
print(" -- Loading model...")
t = time.time()
cache = ExLlamaV2Cache(model, lazy=True)
model.load_autosplit(cache)
t = time.time() - t
print(f" -- Loaded model in {t:.4f} seconds")
print(" -- Loading tokenizer...")
tokenizer = ExLlamaV2Tokenizer(config)
settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0
generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
with torch.inference_mode():
    if not refused_residual_file.exists():
        print('Building refused residual data')
        refused_residual = get_residual(harmful_prompts, 4, True, 2000, 'refused')
        torch.save(refused_residual, refused_residual_file)
    else:
        print('Loading refused residual data')
        refused_residual = torch.load(refused_residual_file)
        print('Done')

    allowed_residual_mean = []
    if not allowed_residual_mean_file.exists():
        if not allowed_residual_file.exists():
            print('Building allowed residual data')
            if not harmless_prompts_file.exists():
                print('Downloading harmless prompts')
                res = requests.get(harmless_prompts_url)
                all_prompts = json.loads(res.content.decode('utf8'))
                harmless_prompts = [i['instruction'] for i in all_prompts if i['input'] == '']
                with harmless_prompts_file.open('w') as f:
                    json.dump(harmless_prompts, f)
                print('Done')
            else:
                with harmless_prompts_file.open('r') as f:
                    harmless_prompts = json.load(f)
            allowed_residual = get_residual(harmless_prompts, 4, True, 2000, 'allowed')
            torch.save(allowed_residual, allowed_residual_file)
        else:
            print('Loading allowed residual data')
            allowed_residual = torch.load(allowed_residual_file)
            print('Done')
        print('Calculating mean allowed residual')
        for i in range(len(allowed_residual)):
            allowed_residual_mean.append(allowed_residual[i].mean(dim = 0))
        print('Done')
        torch.save(allowed_residual_mean, allowed_residual_mean_file)
    else:
        allowed_residual_mean = torch.load(allowed_residual_mean_file)
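    # Per-layer suppression direction: the normalized difference between the mean
    # refused residual and the mean allowed residual (the usual "refusal direction").
    # Each outer iteration re-measures refusals on a fresh sample of harmful prompts
    # with the current directions in place (the wrapper is assumed to project them
    # out of the residual stream), averages the new direction into the old one, and
    # stops once fewer than 30 refused responses are captured from the sample.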
    if model._suppress_dir is None:
        model._suppress_dir = []
    for o in range(6):
        print('Iteration', o)
        for i in range(len(refused_residual)):
            refusal_dir = refused_residual[i].mean(dim = 0) - allowed_residual_mean[i]
            refusal_dir = refusal_dir / refusal_dir.norm() if refusal_dir.norm() > 0.0001 else torch.zeros_like(refusal_dir)
            if len(model._suppress_dir) > i:
                model._suppress_dir[i] = (model._suppress_dir[i] + refusal_dir) / 2
            else:
                model._suppress_dir.append(refusal_dir)
        refused_residual = get_residual(random.sample(harmful_prompts, 2000), 4, True, 50, 'refused')
        if not refused_residual or refused_residual[0].shape[0] < 30:
            break

    save_file({f'_suppress_dir_{layer}': tensor for layer, tensor in enumerate(model._suppress_dir)}, suppress_dir_file)
    torch.cuda.synchronize()