Ablation Code release

#1
by failspy - opened

Released a cookbook-style Python notebook of the code and methodology I used.

You can find it alongside the new Llama-3-70B-Instruct model here:
https://huggingface.co/failspy/llama-3-70B-Instruct-abliterated/blob/main/ortho_cookbook.ipynb

Will this be universal, or unique to Phi?

Owner

As close to universal as I can make it! I'm trying to document the methodology as I go. It will ultimately need some adaptation, but hopefully just a couple of obvious line changes in the script.

Owner

And released the cookbook, along with Llama-3-70B-Instruct with the methodology applied. Thanks for the patience.

Cheers mate, nice to see your code. Here's mine: https://gist.github.com/wassname/42aba7168bb83e278fcfea87e70fa3af. It's not as good, but it uses half the memory.

I also think this function might help, since it lets you batch the activation caching. That was a memory bottleneck for me:

import functools

import torch
import transformer_lens
from tqdm import tqdm

# concat dicts of tensors along `dim`
def concat_dict(a, b, dim=0):
    return {k: torch.cat([a[k], b[k]], dim=dim) for k in a.keys()}

def _concat_activation_caches(a, b, dim=0):
    # wrap the merged dict back into an ActivationCache tied to the same model
    return transformer_lens.ActivationCache(concat_dict(a.cache_dict, b.cache_dict, dim=dim), a.model)

def concat_activation_caches(*caches):
    return functools.reduce(_concat_activation_caches, caches)

def batch_run_with_cache(model, toks, names_filter, batch_size):
    """this batches the run_with_cache function"""
    logits = []
    cache = []
    # dim 0 of toks is the batch dimension, so step over toks.shape[0]
    for i in tqdm(range(0, toks.shape[0], batch_size)):
        toks1 = toks[i:i+batch_size]
        logits1, cache1 = model.run_with_cache(toks1, names_filter=names_filter)
        logits.append(logits1)
        cache.append(cache1)

    cache = concat_activation_caches(*cache)
    logits = torch.cat(logits, dim=0)
    return logits, cache

Nice, @wassname! My code does batch the activation caching, but not as cleanly as yours. I've updated the cookbook with better code that avoids unnecessary memory copies.
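On the memory-copy point: concatenating with functools.reduce, as in the snippet above, re-copies the accumulated tensors at every pairwise step. A minimal sketch of one alternative, assuming plain {name: tensor} dicts instead of an ActivationCache and a hypothetical forward_fn standing in for model.run_with_cache (this is not the cookbook's code): it offloads each batch to CPU and calls torch.cat once per key.

```python
import torch


def concat_cache_dicts(dicts, dim=0, device="cpu"):
    """Concatenate a list of {hook_name: tensor} dicts with one torch.cat per key.

    The single concatenation copies each batch once, instead of re-copying the
    growing result at every pairwise step as functools.reduce + torch.cat does.
    """
    keys = dicts[0].keys()
    return {k: torch.cat([d[k].to(device) for d in dicts], dim=dim) for k in keys}


def batched_activations(forward_fn, toks, batch_size):
    """Run a hypothetical forward_fn(toks_batch) -> {name: tensor} in batches along dim 0."""
    caches = []
    for i in range(0, toks.shape[0], batch_size):  # dim 0 is the batch dimension
        with torch.no_grad():
            acts = forward_fn(toks[i:i + batch_size])
        # offload each batch so accelerator memory only holds the current batch's activations
        caches.append({k: v.to("cpu") for k, v in acts.items()})
    return concat_cache_dicts(caches)


# Toy stand-in for a model: an embedding layer exposing two named "activations".
emb = torch.nn.Embedding(100, 16)


def forward_fn(t):
    h = emb(t)
    return {"resid": h, "resid_mean": h.mean(dim=1)}


toks = torch.randint(0, 100, (10, 5))
cache = batched_activations(forward_fn, toks, batch_size=4)
print({k: tuple(v.shape) for k, v in cache.items()})
```

To get an ActivationCache back, the merged dict could be wrapped with transformer_lens.ActivationCache(merged, model), as the snippet above does.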

Oh I missed that, nice!

I like your conversion from TransformerLens to HF, by the way; very elegant. You must have needed a beast of a machine, because TransformerLens uses extra memory and keeps leaking memory even after you load another model.

Owner

@wassname Yeah, the reason I started with Phi-3 was to get the technique most of the way down before renting a beast of a machine for an hour to do the bigger models :P
In theory, all you need before loading the HF model is the state_dict and the cfg. So if you do struggle with memory, you can export just those, delete your TransformerLens references, and run gc (etc.) to free the rest.
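A minimal sketch of that pattern, assuming a tiny torch.nn.Sequential as a stand-in for the TransformerLens model and a hypothetical cfg dict (a HookedTransformer keeps its config in model.cfg). The helper name extract_weights is illustrative, and the sketch stops before the HF loading step:

```python
import gc
import weakref

import torch
import torch.nn as nn


def extract_weights(model: nn.Module, cfg: dict):
    """Hypothetical helper: copy every state_dict tensor to CPU and return it with a copy of cfg.

    The copied tensors own their storage, so nothing returned here keeps `model`
    alive; the caller still has to drop its own references to the model.
    """
    state_dict = {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}
    return state_dict, dict(cfg)


# Stand-ins for the TransformerLens model and its config (assumptions, not real names).
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
cfg = {"d_model": 4, "d_out": 2}

state_dict, cfg = extract_weights(model, cfg)

model_ref = weakref.ref(model)  # lets us check that the model was actually freed
del model  # drop the last strong reference
gc.collect()  # collect any reference cycles still holding the model
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # return cached GPU blocks to the driver

# ...load the HF model here and copy the tensors from `state_dict` into it...
```

The weakref is only there to confirm the model really was released; in a real script any leftover reference (a hook, a notebook variable) would keep its memory alive.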

Yeah, I ended up making 2 scripts for my TransformerLens version, which is a bit messy.

failspy changed discussion status to closed
