Edit model card
YAML Metadata Warning: empty or missing yaml metadata in repo card (https://huggingface.co/docs/hub/model-cards#model-card-metadata)

Representation Fine-Tuning (ReFT) adaptor for aligning Llama3 with human preferences

  | >99% accuracy on the test set, using a training dataset of fewer than 500 examples |

The following code loads the ReFT-adapted Llama3 and runs inference with it:

import os

from huggingface_hub import login

# Authenticate with the Hugging Face Hub.
# If HF_TOKEN was not defined earlier (e.g. in a previous notebook cell), fall
# back to the HF_TOKEN environment variable instead of failing with NameError.
# This also avoids hard-coding a secret in the script. When no token is found,
# ``token=None`` lets ``login`` ask for one interactively.
try:
    HF_TOKEN
except NameError:
    HF_TOKEN = os.environ.get("HF_TOKEN")

login(
    token=HF_TOKEN,
    add_to_git_credential=True,
)

import torch
import transformers
import pyreft 
from pyreft import ReftModel
from datasets import load_dataset
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

########################
# Load Llama3-8B model #
########################

# Base checkpoint that the ReFT adaptor is applied to; loaded in bfloat16
# directly onto the selected device.
model_name_or_path = "meta-llama/Meta-Llama-3-8B-Instruct"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    device_map=device,
)

# Tokenizer for the base checkpoint, right-padded and capped at 2048 tokens.
model_max_length = 2048
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=model_max_length,
    padding_side="right",
    use_fast=False,
)

# Make sure a pad token exists. Llama-3 checkpoints get a dedicated "[PAD]"
# token (the embedding matrix is resized to match the enlarged vocabulary);
# any other checkpoint reuses its unknown token for padding.
if "Meta-Llama-3-" not in model_name_or_path:
    tokenizer.pad_token = tokenizer.unk_token
else:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))

# Token ids that end generation: the regular EOS token plus "<|eot_id|>".
terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

##################### 
# Load Reft adaptor #
#####################

# Fetch the trained ReFT adaptor from the Hugging Face Hub and attach it to the
# base model.
reft_model = ReftModel.load("Ksgk-fy/Zalinger02_reft_llama3", model, from_huggingface_hub=True)
# Use the device selected earlier rather than a hard-coded "cuda", so the script
# also works on the CPU fallback path and stays consistent with where the base
# model and the tokenized prompt are placed.
reft_model.set_device(device)

# Load the held-out test split and pick one example (index 3) to run through
# the adapted model.
dataset = load_dataset(path="Ksgk-fy/alignment-sft-test2-mode-1", split="test")
data = dataset[3]

# System message placed ahead of the example's prompt in the chat template.
system_prompt = "Follow the instruction closely and provide your answer."

#####################
#   Run Inference   # 
#####################

# Build the chat-formatted prompt string from the system message and the
# example's prompt, then tokenize it and move the tensors to the target device.
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": data['prompt']},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
prompt = tokenizer(prompt, return_tensors="pt").to(device)

# Describe the ReFT setup: one rank-2 LoReFT intervention on the block output
# of each of layers 8, 16 and 24. Below, only the number of representations is
# read from this config.
reft_config = pyreft.ReftConfig(
    representations=[
        {
            "layer": layer_idx,
            "component": "block_output",
            "low_rank_dimension": 2,
            "intervention": pyreft.LoreftIntervention(
                embed_dim=model.config.hidden_size,
                low_rank_dimension=2,
            ),
        }
        for layer_idx in [8, 16, 24]
    ]
)

# Whether the prefix and suffix interventions share weights.
share_weights = True
# Positions to intervene on: the first prompt token (f1) and the last one (l1).
positions = "f1+l1"
first_n, last_n = pyreft.parse_positions(positions)

# Compute the intervention positions for this prompt's length, wrap them in a
# leading dimension of size one, and swap the first two axes before converting
# to nested lists for ``generate``.
intervention_locations = pyreft.get_intervention_locations(
    last_position=prompt["input_ids"].shape[-1],
    first_n=first_n,
    last_n=last_n,
    pad_mode="last",
    num_interventions=len(reft_config.representations),
    share_weights=share_weights,
)
unit_locations = torch.IntTensor([intervention_locations]).permute(1, 0, 2).tolist()

# Sample up to 512 new tokens with the interventions applied on the prompt,
# stopping at any of the terminator ids. ``generate`` returns a pair; only the
# second element (the intervened output) is kept.
intervention_sites = {"sources->base": (None, unit_locations)}
_, reft_response = reft_model.generate(
    prompt,
    unit_locations=intervention_sites,
    intervene_on_prompt=True,
    max_new_tokens=512,
    do_sample=True,
    eos_token_id=terminators,
    early_stopping=True,
)

# Decode the first (and only) generated sequence back into text.
first_sequence = reft_response[0]
response = tokenizer.decode(first_sequence)
Downloads last month
3
Unable to determine this model’s pipeline type. Check the docs .