metadata
library_name: transformers
datasets:
  - tomg-group-umd/wikipedia-en-2k-samples
tags:
  - goldfish-loss
  - memorization
  - mitigation
license: apache-2.0
language:
  - en
pipeline_tag: text2text-generation


Goldfish Loss

We introduce goldfish loss, a new language modeling loss function that mitigates memorization of training data. Goldfish loss pseudorandomly excludes a $1/k$ fraction of the tokens in each training sequence from the loss computation, where $k$ is a hyperparameter. The model still sees every token in the forward pass, but no loss is computed for the dropped tokens. We show that models trained with goldfish loss struggle to regurgitate training data verbatim, even after 100 epochs. Please read our paper linked below for more details.
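The following is a minimal, dependency-free sketch of how a goldfish-style loss could be computed for one sequence. It assumes the drop mask is already given; how the mask is produced is a separate step. Dropped positions contribute nothing to the average next-token negative log-likelihood. The helper names `token_nll` and `goldfish_loss` are hypothetical, and this is not the authors' implementation. In a PyTorch or `transformers` training loop, the same effect is usually achieved by setting the dropped labels to `-100`, which `torch.nn.CrossEntropyLoss` ignores by default.

```python
import math
from typing import Sequence


def token_nll(logits: Sequence[float], target: int) -> float:
    """Negative log-likelihood of `target` under softmax(logits), computed stably."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def goldfish_loss(
    logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    drop: Sequence[bool],
) -> float:
    """Mean NLL over positions NOT dropped by the goldfish mask (hypothetical helper).

    logits[i] scores the vocabulary at position i, targets[i] is the true next
    token there, and drop[i] == True removes position i from the loss. Every
    position is still part of the forward pass; only the loss skips dropped ones.
    """
    if not (len(logits) == len(targets) == len(drop)):
        raise ValueError("logits, targets and drop must have the same length")
    kept = [token_nll(row, t) for row, t, d in zip(logits, targets, drop) if not d]
    if not kept:
        raise ValueError("every position was dropped; the loss is undefined")
    return sum(kept) / len(kept)


# Example: three positions, the last one excluded from the loss.
example_logits = [[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 1.0]]
print(goldfish_loss(example_logits, [0, 1, 2], drop=[False, False, True]))
```

Averaging over kept positions only, rather than over all positions, keeps the loss scale comparable across different values of $k$.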

Overview

The following checkpoints are from our paper titled "Be like a Goldfish, Don't Memorize! Mitigating Memorization in Generative LLMs" [paper link].

| Checkpoint Name | k (goldfish loss) | Token Drop Strategy | Pretrain Tokens | Primary Dataset | Canaries Dataset for Memorization |
|---|---|---|---|---|---|
| tomg-group-umd/3-goldfish-loss-llama-1B | 3 | Hash (width = 13) | 20B | RedPajama | Wikipedia |
| tomg-group-umd/4-goldfish-loss-llama-1B | 4 | Hash (width = 13) | 20B | RedPajama | Wikipedia |
| tomg-group-umd/8-goldfish-loss-llama-1B | 8 | Hash (width = 13) | 20B | RedPajama | Wikipedia |
| tomg-group-umd/32-goldfish-loss-llama-1B | 32 | Hash (width = 13) | 20B | RedPajama | Wikipedia |
| tomg-group-umd/128-goldfish-loss-llama-1B | 128 | Hash (width = 13) | 20B | RedPajama | Wikipedia |
| tomg-group-umd/control-llama-1B | - | No Tokens Dropped | 20B | RedPajama | None |
| tomg-group-umd/standard-loss-llama-1B | - | No Tokens Dropped | 20B | RedPajama | Wikipedia |
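The table lists the token drop strategy as "Hash (width = 13)". One plausible reading of this strategy is sketched below under stated assumptions. In this reading, the decision to drop token i depends only on the 13 tokens immediately before it, hashed to a pseudorandom number. As a result, identical passages receive identical masks wherever and however often they recur, so the tokens a repeated passage drops are never supervised. Several details are assumptions made for illustration and do not come from the paper or codebase: the use of SHA-256, the `% k == 0` drop rule, and always keeping the first 13 positions. The name `hashed_drop_mask` is hypothetical.

```python
import hashlib
from typing import List, Sequence


def hashed_drop_mask(token_ids: Sequence[int], k: int, width: int = 13) -> List[bool]:
    """Return drop[i] == True for positions excluded from the loss (hypothetical sketch).

    The decision for position i depends only on the `width` tokens before it, so
    the same passage gets the same mask wherever it appears. Positions without a
    full window of preceding context are always kept (an assumption).
    """
    if k < 1:
        raise ValueError("k must be >= 1")
    drop = [False] * len(token_ids)
    for i in range(width, len(token_ids)):
        window = token_ids[i - width : i]
        # Serialize the window deterministically and hash it (SHA-256 chosen for illustration).
        digest = hashlib.sha256(",".join(map(str, window)).encode()).digest()
        # Map the first 8 bytes to an integer; this drops roughly one in every k positions.
        drop[i] = int.from_bytes(digest[:8], "big") % k == 0
    return drop
```

The result can be fed directly as the `drop` argument of a masked loss. A purely random per-step mask would, over many repetitions, eventually supervise every token of a repeated document. A context-hashed mask avoids that by construction.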

Description

  • standard-loss-llama-1B and control-llama-1B are trained with the standard causal language modeling loss. Apart from the loss, their training setup is exactly the same as that of the goldfish models.
  • The control model differs from standard-loss-llama-1B only in that it never sees the canaries dataset: it is pre-trained solely on 20B RedPajama tokens.
  • The canaries dataset, which contains 2,000 Wikipedia documents, is repeated 50 times throughout pre-training, for a total of roughly 204M tokens (including padding).
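The ~204M figure is consistent with each canary document occupying one full training sequence. The quick check below assumes a padded length of 2,048 tokens per document, which is TinyLlama's context length. The card does not state the padded length explicitly, so this value is an assumption.

```python
NUM_DOCS = 2_000      # Wikipedia documents in the canaries dataset (from the card)
REPETITIONS = 50      # times the canaries dataset is repeated during pre-training (from the card)
SEQ_LEN = 2_048       # ASSUMED padded length per document (TinyLlama context length)

total_tokens = NUM_DOCS * REPETITIONS * SEQ_LEN
print(f"{total_tokens:,} canary tokens (~{total_tokens / 1e6:.1f}M)")
```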

Technical Specification

Each checkpoint listed above uses a randomly initialized TinyLlama-1.1B architecture; that is, it was trained from scratch rather than from released TinyLlama weights. For pretraining details, please see our GitHub repository.
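The checkpoints use a LLaMA-family architecture, so they should load through the standard `transformers` Auto classes. The sketch below builds the repository IDs from the checkpoint table and then loads one checkpoint. It assumes that each repository also ships its tokenizer files, which is not confirmed here. The names `CHECKPOINTS` and `load_checkpoint` are hypothetical. `transformers` is imported inside the function so that the ID mapping can be used without the library installed.

```python
from typing import Dict, Union

# Repository IDs from the checkpoint table: integer keys are k, plus the two baselines.
CHECKPOINTS: Dict[Union[int, str], str] = {
    **{k: f"tomg-group-umd/{k}-goldfish-loss-llama-1B" for k in (3, 4, 8, 32, 128)},
    "standard": "tomg-group-umd/standard-loss-llama-1B",
    "control": "tomg-group-umd/control-llama-1B",
}


def load_checkpoint(key: Union[int, str]):
    """Download and return (model, tokenizer) for one checkpoint (hypothetical helper).

    Assumes the repository also contains tokenizer files.
    """
    # Imported lazily so CHECKPOINTS is usable without transformers installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = CHECKPOINTS[key]
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id)
    return model, tokenizer
```

For example, after `model, tok = load_checkpoint(4)`, you can encode a prompt with `tok(prompt, return_tensors="pt")` and pass it to `model.generate(...)`. Comparing the continuations against the standard-loss checkpoint is one informal way to observe the difference in verbatim recall.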

Cite our work

If you find our model, codebase or dataset beneficial, please consider citing our work:

@misc{hans2024like,
      title={Be like a Goldfish, Don't Memorize! Mitigating Memorization in Generative LLMs}, 
      author={Abhimanyu Hans and Yuxin Wen and Neel Jain and John Kirchenbauer and Hamid Kazemi and Prajwal Singhania and Siddharth Singh and Gowthami Somepalli and Jonas Geiping and Abhinav Bhatele and Tom Goldstein},
      year={2024},
      eprint={2406.10209},
      archivePrefix={arXiv},
}